Compare commits

...

51 Commits

Author SHA1 Message Date
d01207dbf3 Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.

* Add some KvCache tests.

* Test the reset too.

* More kv-cache testing.

* More tests for the rotating kv-cache.

* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.

* Handle contiguity + bugfix + use in mimi.

* Add a way to test the mimi streaming mode.

* Mimi streaming fixes.

* More rotating kv-cache.

* Fix the attn mask generation.

* Handle the abs case.

* Add some tests for the generated mask.
2024-09-23 13:14:32 +02:00
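For context, a minimal sketch of the rotating-cache behaviour described in the commit above, using a plain Vec rather than candle's actual RotatingKVCache API (all names here are illustrative): once the buffer is full, new entries overwrite the oldest ones, and a source that is at least as large as the capacity is returned unchanged.

```rust
// Illustrative ring-buffer sketch, not the candle implementation.
struct RotatingCache {
    data: Vec<f32>,
    capacity: usize,
    offset: usize, // next write position
}

impl RotatingCache {
    fn new(capacity: usize) -> Self {
        Self { data: Vec::with_capacity(capacity), capacity, offset: 0 }
    }

    // Appends `src` and returns the cache contents; an overlarge `src` is
    // returned as-is, mirroring the behaviour described in the commit.
    fn append<'a>(&'a mut self, src: &'a [f32]) -> &'a [f32] {
        if src.len() >= self.capacity {
            return src;
        }
        for &v in src {
            if self.data.len() < self.capacity {
                self.data.push(v);
            } else {
                self.data[self.offset] = v;
            }
            self.offset = (self.offset + 1) % self.capacity;
        }
        &self.data
    }
}

fn main() {
    let mut cache = RotatingCache::new(4);
    assert_eq!(cache.append(&[1., 2., 3.]), &[1., 2., 3.]);
    // The oldest entry (1.) is overwritten once the capacity of 4 is exceeded.
    assert_eq!(cache.append(&[4., 5.]), &[5., 2., 3., 4.]);
    // An input larger than the cache is returned unchanged.
    assert_eq!(cache.append(&[9.; 6]), &[9.; 6]);
}
```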
8097559c1a Move the candle version to 0.7.1. (#2495) 2024-09-22 20:44:39 +02:00
829dcfa8dc Update cudarc to 0.12.1. (#2494) 2024-09-22 20:32:29 +02:00
c2fca0ca11 Bump the crate version. (#2491) 2024-09-21 15:13:12 +02:00
844d45cde4 Bugfix for the metal elu kernel. (#2490)
* Bugfix for the metal elu kernel.

* Add a test.
2024-09-21 15:03:19 +02:00
af2104078f Metal commands refactoring (#2489)
* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
2024-09-21 13:18:42 +02:00
5fc4f17727 Adding Granite 7b Instruct model example (#2487)
* Adding Granite 7b Instruct model example

* Minor refactoring to make it a little more idiomatic

* Clippy fixes.

* Adding a README with some information about supported Granite models
* Changing the default prompt to better accommodate the language
  modality of the Granite 7b Instruct model

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-09-21 11:52:01 +02:00
c58c5d5b01 Add the mimi audio-tokenizer. (#2488)
* Add the mimi audio-tokenizer.

* Formatting tweaks.

* Add a full example.

* Use the transformers names.

* More renamings.

* Get encoding and decoding to work.

* Clippy fixes.
2024-09-20 14:31:20 -06:00
382c6b51af Improve error message (#2485) 2024-09-20 07:11:41 -06:00
6eea45a761 Add a couple cast metal kernels. (#2479) 2024-09-15 22:27:46 +02:00
ebf722b446 Export TensorIndexer public to candle users (#2477) 2024-09-13 22:21:57 +02:00
c09afc211c Fix for metal tanh. (#2475) 2024-09-13 07:08:36 +02:00
b60faebea4 Missing metal kernels. (#2474) 2024-09-12 13:58:50 +02:00
72d649058b Hook the MLX matmul kernels in candle-core. (#2473) 2024-09-12 13:52:59 +02:00
0cb0bd1dfa Add some metal gemm benchmark. (#2471)
* Add some metal gemm benchmark.

* More benchmarks.
2024-09-11 22:52:37 +02:00
afb6575835 Use the new MLX kernels to handle the BF16 matmul. (#2470) 2024-09-11 17:34:05 +02:00
5635650d38 Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels.

* Clippy lints.

* Export the gemm_f32 kernel.

* Add the f16/bf16 variants.

* Add the initial dispatch code.

* More plugging of the mlx kernels.

* Add a currently broken test.

* Tweaks.

* Bugfix + get the tests to pass.

* Enable the gemm bf16 tests.

* Add some randomized tests.

* Update candle-metal-kernels/src/lib.rs

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>

* More fixes.

* More clippy fixes.

---------

Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
2024-09-11 16:56:48 +02:00
13b2a8a4a0 Complete the missing backticks in the comments (#2469) 2024-09-11 16:37:05 +02:00
e3261216b1 Clippy fixes for 1.81.0. (#2461)
* Clippy fixes for 1.81.0.

* Another fix.
2024-09-05 23:46:55 +02:00
c02b7c3272 Fix FLUX.1 weights (#2457)
* fix FLUX.1 weights

* added flux1-dev.safetensors
2024-08-29 17:10:28 +02:00
86613c00e2 MobileCLIP models S1 and S2 (#2454)
* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-29 15:38:58 +02:00
29e25c458d FastViT fixes. (#2452)
* correct optional SE layer dimensions.
 * head_dim, not num_heads, is 32.
 * update test example output.
2024-08-28 11:20:09 +02:00
aafa24ed93 Update cudarc to 0.12. (#2451)
* Update cudarc to 0.12.

* Some cudnn tweaks.
2024-08-27 10:10:30 +02:00
fdc2622686 fix: qwen2 lm_head loading #2443 (#2445)
Co-authored-by: Yi Xu <xuyi@me.com>
2024-08-23 16:50:02 +02:00
ccdbe87639 Add FastViT model. (#2444) 2024-08-23 16:06:54 +02:00
2ec8729d51 Fix for parler-tts, do not add the last slice of padding tokens. (#2442)
* Fix for parler-tts, do not add the last slice of padding tokens.

* Support for the mini model.
2024-08-22 23:22:03 +02:00
e3c146ada6 silero-vad v5 example (#2321)
* silero-vad v5 example

This change adds an example of how to run silero-vad v5

* PR: rename 'vad' to 'silero-vad'

* Update README.md

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-08-22 22:50:42 +02:00
1e96b8b695 onnx: support negative index in Gather (#2440)
index_select does not support negative indexing, but
this change adds just enough workarounds in onnx to
allow evaluating silero-vad models (which make use of
negative indices).
2024-08-22 15:28:25 +02:00
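For reference, a small sketch of the kind of index normalization this workaround implies, with hypothetical names (candle's index_select expects non-negative indices, so a negative ONNX Gather index is shifted by the axis length first):

```rust
// Hypothetical helper illustrating the normalization; not the actual onnx.rs code.
fn normalize_gather_indices(indices: &[i64], axis_len: i64) -> Vec<u32> {
    indices
        .iter()
        .map(|&i| if i < 0 { (i + axis_len) as u32 } else { i as u32 })
        .collect()
}

fn main() {
    // On an axis of length 5, index -1 maps to 4 and -5 maps to 0.
    assert_eq!(normalize_gather_indices(&[0, -1, 2, -5], 5), vec![0, 4, 2, 0]);
}
```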
a8288b7a72 onnx: workaround pow with negative base (#2439)
* onnx: workaround pow with negative base

rather than fully defining pow in the cpu backend (as in #2318),
this implements a much smaller change which is sufficient to evaluate silero-vad
onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so
evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`.

* PR: use Tensor::powf instead

powf correctly handles a negative base.
2024-08-22 13:34:53 +02:00
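The rationale above can be checked with plain floats: the generic e^(2·ln(x)) rewrite is NaN for a negative base, while both the x*x special case and powf give the expected result.

```rust
fn main() {
    let x = -3.0f64;
    // ln of a negative number is NaN, so the exp/ln rewrite fails here.
    assert!((2.0 * x.ln()).exp().is_nan());
    // The 2.0-exponent special case and powf both handle the negative base.
    assert_eq!(x * x, 9.0);
    assert_eq!(x.powf(2.0), 9.0);
}
```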
6070278a31 Bump the version to 0.6.1. (#2438) 2024-08-22 09:23:52 +02:00
b47c0bc475 Update README.md (#2435) 2024-08-19 09:34:24 +02:00
14fd2d97e0 Add a readme for the parler-tts example. (#2434)
* Add a readme for the parler-tts example.

* Remove the python decode script.

* mp4 tweaks.

* Another readme tweak.
2024-08-19 09:30:12 +02:00
31a1075f4b onnx: implement LSTM op (#2268)
use candle-nn LSTM
2024-08-19 09:06:17 +02:00
236b29ff15 Add the DAC model. (#2433)
* Add the DAC model.

* More quantization support.

* Handle DAC decoding.

* Plug the DAC decoding in parler-tts.
2024-08-19 08:59:51 +02:00
58197e1896 parler-tts support (#2431)
* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description + t5 encode it.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the python decoder.

* Proper causality mask.
2024-08-18 20:42:08 +02:00
736d8eb752 Stream tensor (#2429)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.

* Add StreamTensor.
2024-08-17 21:54:28 +02:00
7cff5898ec Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).

* Forces u to be strictly positive.
2024-08-17 21:29:01 +02:00
b75ef051cf Fix the marian tokenizer importer. (#2426)
* Fix the marian tokenizer importer.

* Ignore the python caches.
2024-08-17 20:58:40 +02:00
c1b9e07e35 Add support for gemma-2. (#2425)
* Add gemma-2.

* Support a couple more models.

* Sliding window support.

* Example + readme updates.

* Update the main readme.
2024-08-17 20:31:23 +02:00
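As a rough illustration of the sliding-window attention mentioned above (a sketch only, not the gemma-2 code from this commit): position i may attend to position j when j <= i and i - j < window.

```rust
// Plain Vec-based sketch of a sliding-window causal mask.
fn sliding_window_mask(seq_len: usize, window: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i && i - j < window).collect())
        .collect()
}

fn main() {
    let mask = sliding_window_mask(5, 3);
    // With a window of 3, position 4 can attend to 2, 3 and 4 but not to 0 or 1.
    assert_eq!(mask[4], vec![false, false, true, true, true]);
}
```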
69fdcfe96a Apply rustfmt. (#2421) 2024-08-16 18:57:14 +02:00
2b75dd9551 Fix build issue in EOS Token in llama-multiprocess (#2420) 2024-08-16 18:46:31 +02:00
53ce65f706 Clippy fixes. (#2415)
* Clippy fixes.

* Bump the web_sys required version.
2024-08-14 10:13:53 +02:00
68aa9c7320 Fix the device for the bert attention mask. (#2414) 2024-08-14 10:01:12 +02:00
35e5f31397 Add Based LLM from Hazy Research. (#2411) 2024-08-12 21:21:19 +02:00
d3fe989d08 Add documentation examples for Tensor::i and Tensor::narrow methods (#2308)
* Add documentation examples for `Tensor` methods

* Apply fmt.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-10 08:11:09 +02:00
14db029494 Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
2024-08-10 07:57:52 +02:00
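The commit bullets above suggest the standard Gaussian soft-NMS update; as a hedged sketch (illustrative names, not the candle API): an overlapping box keeps its score decayed by exp(-iou²/sigma) and is zeroed, rather than removed, once it drops below the confidence threshold.

```rust
// Illustrative Gaussian soft-NMS decay; not the actual candle implementation.
fn soft_nms_decay(confidence: f32, iou: f32, sigma: f32, threshold: f32) -> f32 {
    let decayed = confidence * (-iou * iou / sigma).exp();
    if decayed < threshold { 0.0 } else { decayed }
}

fn main() {
    // Heavy overlap decays the score; below the threshold it is set to 0.0.
    let c = soft_nms_decay(0.9, 0.8, 0.5, 0.1);
    assert!(c > 0.2 && c < 0.3);
    assert_eq!(soft_nms_decay(0.9, 0.95, 0.5, 0.3), 0.0);
}
```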
6e6c1c99b0 Fix issues in the encodec example README.md (#2407)
Also squeeze the first dimension of the codes tensor in the example file to get the expected three dimensions.
2024-08-10 07:49:05 +02:00
b7d9af00cc fix: usage of actions/checkout@v2 (#2403)
* chore: changes from formatting on save

* fix: usage of `actions/checkout@v2`
2024-08-06 10:59:34 +02:00
59bbc0d287 Add the import script for the T5 tokenizer. (#2399) 2024-08-05 21:03:31 +02:00
dfdce2b602 Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3

lint

add comments

* correct a misplaced comment

* fix cargo fmt

* fix clippy error

* use bail! instead of assert!

* use get_on_dim in splitting qkv
2024-08-05 19:26:15 +02:00
500c9f2882 add models support and example for THUDM/glm-4 (#2362)
* add models support and example for THUDM/glm-4

* fix the ci report

* fmt

* fix

* Update README.org

* Update README.org

* fmt

* Update README.org

* README.md add codegeex4

* README.md add glm4

* Typo.

* change expect into ?

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2024-08-05 17:48:09 +02:00
109 changed files with 13540 additions and 338 deletions

View File

@ -18,9 +18,9 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Install Rust
uses: actions-rs/toolchain@v1
@ -65,4 +65,4 @@ jobs:
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests
python -m pytest -s -v tests

View File

@ -1,6 +1,6 @@
on:
on:
push:
branches:
branches:
- main
pull_request:
@ -15,7 +15,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -34,7 +34,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -49,7 +49,7 @@ jobs:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -65,7 +65,7 @@ jobs:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal

6
.gitignore vendored
View File

@ -40,3 +40,9 @@ candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.6.0"
version = "0.7.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
candle-nn = { path = "./candle-nn", version = "0.6.0" }
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"

View File

@ -63,7 +63,9 @@ We also provide a some command line based examples using state of the art models
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion, code interpreter, web search, function calling, repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
@ -118,6 +120,8 @@ We also provide a some command line based examples using state of the art models
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -206,7 +210,7 @@ If you have an addition to this list, please submit a pull request.
- StarCoder, StarCoder2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma 2b and 7b.
- Gemma v1 2b and 7b+, v2 2b and 9b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@ -234,9 +238,10 @@ If you have an addition to this list, please submit a pull request.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.

View File

@ -1,6 +1,6 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
use cudarc::cudnn::safe::{ConvForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d<
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = Conv2dForward {
let conv2d = ConvForward {
conv: &conv,
x: &x,
w: &w,

View File

@ -174,6 +174,7 @@ impl Map1 for Im2Col1D {
}
}
#[allow(unused)]
struct Im2Col {
h_k: usize,
w_k: usize,
@ -183,6 +184,7 @@ struct Im2Col {
}
impl Im2Col {
#[allow(unused)]
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
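A quick standalone check of the output-size formula used in hw_out above (same arithmetic, plain usize):

```rust
fn out_dim(size: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (size + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    // 5x5 input, 3x3 kernel, padding 1, stride 1: the spatial size is preserved.
    assert_eq!(out_dim(5, 3, 1, 1, 1), 5);
    // The same input with no padding and stride 2 shrinks to 2.
    assert_eq!(out_dim(5, 3, 0, 2, 1), 2);
}
```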

View File

@ -173,8 +173,8 @@ impl Device {
pub fn supports_bf16(&self) -> bool {
match self {
Self::Cuda(_) => true,
Self::Metal(_) | Self::Cpu => false,
Self::Cuda(_) | Self::Metal(_) => true,
Self::Cpu => false,
}
}

View File

@ -141,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0., 1.],
/// [2., 3.],
/// [4., 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i(0)?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
///
/// let c = a.i(..2)?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f64>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i(1..)?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f64>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0f32, 1.],
/// [2. , 3.],
/// [4. , 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i((0,))?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
///
/// let c = a.i((..2,))?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i((1..,))?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f32>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
self.index(&[a.into()])
}
}
#[allow(non_snake_case)]
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
///
/// let b = a.i((1, 0))?;
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
///
/// let c = a.i((..2, 1))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(d.shape().dims(), &[1, 3]);
/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
self.index(&[a.into(), b.into()])
}
}
macro_rules! index_op_tuple {
($($t:ident),+) => {
($doc:tt, $($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
#[doc=$doc]
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);

View File

@ -65,6 +65,7 @@ pub mod scalar;
pub mod shape;
mod sort;
mod storage;
pub mod streaming;
mod strided_index;
mod tensor;
mod tensor_cat;
@ -80,10 +81,11 @@ pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, Inp
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::IndexOp;
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;

View File

@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use std::sync::{Arc, Mutex, RwLock};
use super::MetalError;
@ -22,7 +22,73 @@ impl DeviceId {
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
pub(crate) struct Commands {
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
command_buffer: CommandBuffer,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: usize,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
}
impl Commands {
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
Ok(Self {
command_queue,
command_buffer,
command_buffer_index: 0,
compute_per_buffer,
})
}
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
let mut command_buffer = self.command_buffer.to_owned();
let mut flushed = false;
if self.command_buffer_index > self.compute_per_buffer {
self.command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
self.command_buffer = command_buffer.clone();
self.command_buffer_index = 0;
flushed = true;
}
self.command_buffer_index += 1;
Ok((flushed, command_buffer))
}
pub fn wait_until_completed(&mut self) -> Result<()> {
match self.command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
}
#[derive(Clone)]
pub struct MetalDevice {
@ -33,27 +99,8 @@ pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
pub(crate) commands: Arc<RwLock<Commands>>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@ -67,9 +114,15 @@ pub struct MetalDevice {
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: AllocatedBuffers,
pub(crate) buffers: Arc<RwLock<BufferMap>>,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones.
pub(crate) use_mlx_mm: bool,
}
impl std::fmt::Debug for MetalDevice {
@ -87,6 +140,10 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
self.use_mlx_mm = use_mlx_mm
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -95,44 +152,31 @@ impl MetalDevice {
&self.device
}
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
let mut commands = self.commands.write().map_err(MetalError::from)?;
let (flushed, command_buffer) = commands.command_buffer()?;
if flushed {
self.drop_unused_buffers()?
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
let mut commands = self.commands.write().map_err(MetalError::from)?;
commands.wait_until_completed()
}
pub fn kernels(&self) -> &Kernels {
@ -180,6 +224,7 @@ impl MetalDevice {
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@ -210,40 +255,6 @@ impl MetalDevice {
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
@ -252,7 +263,7 @@ impl MetalDevice {
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
if let Some(b) = find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
@ -291,3 +302,23 @@ impl MetalDevice {
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}
fn find_available_buffer(
size: NSUInteger,
option: MTLResourceOptions,
buffers: &BufferMap,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

View File

@ -412,17 +412,42 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let kernel_name = match (self.dtype, dtype) {
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::F16, DType::I64) => "cast_f16_i64_strided",
(DType::F16, DType::U32) => "cast_f16_u32_strided",
(DType::F16, DType::U8) => "cast_f16_u8_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F32, DType::I64) => "cast_f32_i64_strided",
(DType::F32, DType::U32) => "cast_f32_u32_strided",
(DType::F32, DType::U8) => "cast_f32_u8_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
(DType::I64, DType::F16) => "cast_i64_f16_strided",
(DType::I64, DType::U32) => "cast_i64_u32_strided",
(DType::I64, DType::U8) => "cast_i64_u8_strided",
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
(DType::U32, DType::F16) => "cast_u32_f16_strided",
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
(DType::U8, DType::F16) => "cast_u8_f16_strided",
(DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::U8, DType::I64) => "cast_u8_i64_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}
@ -1398,6 +1423,7 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
Ok(acc)
}
fn matmul(
&self,
rhs: &Self,
@ -1406,32 +1432,78 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
DType::BF16 => "bgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else if self.device.use_mlx_mm {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(
buffer,
self.device.clone(),
@ -1792,31 +1864,25 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
commands: Arc::new(RwLock::new(commands)),
buffers: Arc::new(RwLock::new(HashMap::new())),
kernels,
seed,
use_mlx_mm,
})
}
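Per the parsing above, the MLX matmul path can be toggled through the CANDLE_USE_MLX_MM environment variable (anything other than false/0 enables it) or via MetalDevice::set_use_mlx_mm; a minimal sketch, assuming a build with the metal feature on an Apple machine:

```rust
use candle_core::Device;

fn main() -> candle_core::Result<()> {
    // The variable is read when the Metal device is created, so set it first.
    std::env::set_var("CANDLE_USE_MLX_MM", "1");
    let _device = Device::new_metal(0)?;
    Ok(())
}
```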

View File

@ -304,6 +304,7 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
Minus(usize),
}
impl D {
@ -311,6 +312,7 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@ -327,6 +329,7 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@ -336,6 +339,7 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
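A small check of the new D::Minus(u) resolution shown above, assuming the public Dim trait path candle_core::shape::Dim: on a rank-3 shape, Minus(1) is the last axis and Minus(3) the first.

```rust
use candle_core::shape::Dim;
use candle_core::{Shape, D};

fn main() -> candle_core::Result<()> {
    let shape = Shape::from((2, 3, 4));
    assert_eq!(D::Minus(1).to_index(&shape, "example")?, 2);
    assert_eq!(D::Minus(3).to_index(&shape, "example")?, 0);
    // Equivalent to the pre-existing Minus1 variant.
    assert_eq!(D::Minus1.to_index(&shape, "example")?, 2);
    Ok(())
}
```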

View File

@ -0,0 +1,206 @@
use crate::{Result, Shape, Tensor};
pub trait Dim: crate::shape::Dim + Copy {}
impl<T: crate::shape::Dim + Copy> Dim for T {}
/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
/// empty.
#[derive(Clone)]
pub struct StreamTensor(Option<Tensor>);
impl std::fmt::Debug for StreamTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(t) => write!(f, "{:?}", t.shape()),
None => write!(f, "Empty"),
}
}
}
impl std::convert::From<Option<Tensor>> for StreamTensor {
fn from(value: Option<Tensor>) -> Self {
Self(value)
}
}
impl std::convert::From<Tensor> for StreamTensor {
fn from(value: Tensor) -> Self {
Self(Some(value))
}
}
impl std::convert::From<()> for StreamTensor {
fn from(_value: ()) -> Self {
Self(None)
}
}
impl StreamTensor {
pub fn empty() -> Self {
Self(None)
}
pub fn from_tensor(tensor: Tensor) -> Self {
Self(Some(tensor))
}
pub fn shape(&self) -> Option<&Shape> {
self.0.as_ref().map(|t| t.shape())
}
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
let xs = match (&self.0, &rhs.0) {
(Some(lhs), Some(rhs)) => {
let xs = Tensor::cat(&[lhs, rhs], dim)?;
Some(xs)
}
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
(None, None) => None,
};
Ok(Self(xs))
}
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
match &self.0 {
None => Ok(0),
Some(v) => v.dim(dim),
}
}
pub fn reset(&mut self) {
self.0 = None
}
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
let t = match &self.0 {
None => None,
Some(t) => {
let seq_len = t.dim(dim)?;
if seq_len <= offset {
None
} else {
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
Some(t)
}
}
};
Ok(Self(t))
}
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
/// returned in the first output and the remaining in the second output.
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
match &self.0 {
None => Ok((Self::empty(), Self::empty())),
Some(t) => {
let seq_len = t.dim(dim)?;
let lhs_len = usize::min(seq_len, lhs_len);
if lhs_len == 0 {
Ok((Self::empty(), t.clone().into()))
} else {
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
let rhs_len = seq_len - lhs_len;
let rhs = if rhs_len == 0 {
Self::empty()
} else {
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
};
Ok((lhs, rhs))
}
}
}
}
pub fn as_option(&self) -> Option<&Tensor> {
self.0.as_ref()
}
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
match &self.0 {
None => Ok(Self::empty()),
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
}
}
}
/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
/// some internal buffering so that enough data has been received for the module to be able to
/// perform some operations.
pub trait StreamingModule {
// TODO: Should we also have a flush method?
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
fn reset_state(&mut self);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
Add,
Mul,
Sub,
Div,
}
#[derive(Debug, Clone)]
pub struct StreamingBinOp {
prev_lhs: StreamTensor,
prev_rhs: StreamTensor,
pub op: BinOp,
pub dim: crate::D,
}
impl StreamingBinOp {
pub fn new(op: BinOp, dim: crate::D) -> Self {
Self {
prev_lhs: StreamTensor::empty(),
prev_rhs: StreamTensor::empty(),
op,
dim,
}
}
pub fn reset_state(&mut self) {
self.prev_lhs.reset();
self.prev_rhs.reset();
}
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
match self.op {
BinOp::Add => Tensor::add(lhs, rhs),
BinOp::Mul => Tensor::mul(lhs, rhs),
BinOp::Sub => Tensor::sub(lhs, rhs),
BinOp::Div => Tensor::div(lhs, rhs),
}
}
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
let lhs_len = lhs.seq_len(self.dim)?;
let rhs_len = rhs.seq_len(self.dim)?;
let common_len = usize::min(lhs_len, rhs_len);
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
let ys = match (lhs.0, rhs.0) {
(Some(lhs), Some(rhs)) => {
let ys = self.forward(&lhs, &rhs)?;
StreamTensor::from_tensor(ys)
}
(None, None) => StreamTensor::empty(),
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
};
self.prev_lhs = prev_lhs;
self.prev_rhs = prev_rhs;
Ok(ys)
}
}
/// Simple wrapper that doesn't do any buffering.
pub struct Map<T: crate::Module>(T);
impl<T: crate::Module> StreamingModule for Map<T> {
fn reset_state(&mut self) {}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
xs.apply(&self.0)
}
}
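A usage sketch of the StreamingBinOp defined above: chunks of different lengths are buffered so that only the aligned prefix is combined at each step (BinOp, StreamTensor and StreamingBinOp are taken from the streaming module as declared in this file).

```rust
use candle_core::streaming::{BinOp, StreamTensor, StreamingBinOp};
use candle_core::{Device, Tensor, D};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let mut add = StreamingBinOp::new(BinOp::Add, D::Minus1);

    // lhs has 3 elements, rhs only 2: the step emits the 2 aligned sums and
    // buffers the extra lhs element.
    let lhs = StreamTensor::from_tensor(Tensor::new(&[1f32, 2., 3.], &dev)?);
    let rhs = StreamTensor::from_tensor(Tensor::new(&[10f32, 20.], &dev)?);
    let ys = add.step(&lhs, &rhs)?;
    assert_eq!(ys.as_option().unwrap().to_vec1::<f32>()?, [11., 22.]);

    // The next rhs chunk is combined with the buffered lhs element.
    let ys = add.step(&().into(), &Tensor::new(&[30f32], &dev)?.into())?;
    assert_eq!(ys.as_option().unwrap().to_vec1::<f32>()?, [33.]);
    Ok(())
}
```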

View File

@ -370,6 +370,15 @@ impl Tensor {
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [3.5, 3.5, 3.5, 3.5],
/// [3.5, 3.5, 3.5, 3.5],
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
@ -379,6 +388,13 @@ impl Tensor {
}
/// Creates a new 1D tensor from an iterator.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device,
@ -390,12 +406,26 @@ impl Tensor {
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `1` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
Self::arange_step(start, end, D::one(), device)
}
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
@ -441,6 +471,16 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input vector. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
/// If the device is cpu, no data copy is made.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [1., 2., 3.],
/// [4., 5., 6.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
@ -451,6 +491,17 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input slice. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
///```rust
/// use candle_core::{Tensor, Device};
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [2., 3., 4.],
/// [5., 6., 7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
@ -732,6 +783,30 @@ impl Tensor {
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
/// ```
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[
/// [0f32, 1., 2.],
/// [3. , 4., 5.],
/// [6. , 7., 8.]
/// ], &Device::Cpu)?;
///
/// let b = a.narrow(0, 1, 2)?;
/// assert_eq!(b.shape().dims(), &[2, 3]);
/// assert_eq!(b.to_vec2::<f32>()?, &[
/// [3., 4., 5.],
/// [6., 7., 8.]
/// ]);
///
/// let c = a.narrow(1, 1, 1)?;
/// assert_eq!(c.shape().dims(), &[3, 1]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [1.],
/// [4.],
/// [7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
@ -1950,7 +2025,11 @@ impl Tensor {
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!("not implemented yet")
bail!(
"not implemented yet, self.device: {:?}, device: {:?}",
self.device(),
device
)
}
};
let op = BackpropOp::new1(self, Op::ToDevice);

View File

@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> {
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = tensor.elu(2.)?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
// This test failed on metal prior to the following PR:
// https://github.com/huggingface/candle/pull/2490
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, -1.7293, 0.0000, 3.0000]
);
Ok(())
}

View File

@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@ -101,6 +102,10 @@ required-features = ["candle-datasets"]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "encodec"
required-features = ["encodec"]
@ -108,3 +113,7 @@ required-features = ["encodec"]
[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]

View File

@ -0,0 +1,20 @@
# candle-based
An experimental, non-instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers.
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
## Running an example
```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100
Flying monkeys are a common sight in the wild, but they are also a threat to humans.
The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.
"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
```

View File

@ -0,0 +1,275 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::based::Model;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "360m")]
W360m,
#[value(name = "1b")]
W1b,
#[value(name = "1b-50b")]
W1b50b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "refs/pr/1")]
revision: String,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "360m")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::W360m => "hazyresearch/based-360m".to_string(),
Which::W1b => "hazyresearch/based-1b".to_string(),
Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let config_file = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let repo = api.model("openai-community/gpt2".to_string());
let tokenizer_file = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.which == Which::W1b50b {
vb = vb.pp("model");
};
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -7,7 +7,7 @@ quantization.
## Running one example
```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
cargo run --example encodec --features encodec --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```

View File

@ -0,0 +1,20 @@
# candle-fastvit
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
This candle implementation uses a pre-trained FastViT network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12
loaded image Tensor[dims 3, 256, 256; f32]
model built
mountain bike, all-terrain bike, off-roader: 52.67%
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
unicycle, monocycle : 3.46%
maillot : 1.32%
crash helmet : 1.28%
```

View File

@ -0,0 +1,102 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::fastvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
T8,
T12,
S12,
SA12,
SA24,
SA36,
MA36,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::T8 => "t8",
Self::T12 => "t12",
Self::S12 => "s12",
Self::SA12 => "sa12",
Self::SA24 => "sa24",
Self::SA36 => "sa36",
Self::MA36 => "ma36",
};
format!("timm/fastvit_{}.apple_in1k", name)
}
fn config(&self) -> fastvit::Config {
match self {
Self::T8 => fastvit::Config::t8(),
Self::T12 => fastvit::Config::t12(),
Self::S12 => fastvit::Config::s12(),
Self::SA12 => fastvit::Config::sa12(),
Self::SA24 => fastvit::Config::sa24(),
Self::SA36 => fastvit::Config::sa36(),
Self::MA36 => fastvit::Config::ma36(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S12)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = fastvit::fastvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -147,8 +147,8 @@ fn run(args: Args) -> Result<()> {
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
Model::Dev => bf_repo.get("flux1-dev.sft")?,
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
@ -189,7 +189,7 @@ fn run(args: Args) -> Result<()> {
println!("latent img\n{img}");
let img = {
let model_file = bf_repo.get("ae.sft")?;
let model_file = bf_repo.get("ae.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),

View File

@ -0,0 +1,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_MODEL = "google/t5-v1_1-xxl"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json
tokenizer.save_pretrained("/tmp/tokenizer/")

View File

@ -1,27 +1,27 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
models published by Google DeepMind with a 2b and a 7b variant for the first
version, and a 2b and a 9b variant for v2.
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
$ cargo run --example gemma --features cuda -r -- \
--prompt "Here is a proof that square root of 2 is not rational: "
Here is a proof that square root of 2 is not rational:
Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
```
## Access restrictions
In order to use the v1 examples, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config, Model};
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -38,6 +39,46 @@ enum Which {
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
}
impl Which {
fn is_v1(&self) -> bool {
match self {
Self::Base2B
| Self::Base7B
| Self::Instruct2B
| Self::Instruct7B
| Self::InstructV1_1_2B
| Self::InstructV1_1_7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::CodeInstruct2B
| Self::CodeInstruct7B => true,
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
}
}
}
enum Model {
V1(Model1),
V2(Model2),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
}
}
}
struct TextGeneration {
@ -191,7 +232,7 @@ struct Args {
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2b")]
#[arg(long, default_value = "2-2b")]
which: Which,
#[arg(long)]
@ -239,6 +280,10 @@ fn main() -> Result<()> {
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -263,7 +308,6 @@ fn main() -> Result<()> {
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
@ -273,7 +317,15 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(args.use_flash_attn, &config, vb)?;
let model = if args.which.is_v1() {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
} else {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
};
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,77 @@
* GLM4
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
- [[https://github.com/THUDM/GLM4][Github]]
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换DFT的方法它广泛应用于信号处理、图像处理和数据分析等领域。
具体来说FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中我们通常需要知道信号的频率成分这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
```python
import numpy as np
# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# 对该信号做FFT变换并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
#+end_src
This example reads the prompt from stdin
* Citation
#+begin_src
@misc{glm2024chatglm,
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
year={2024},
eprint={2406.12793},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
#+end_src
#+begin_src
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
#+end_src

View File

@ -0,0 +1,255 @@
use candle_transformers::models::glm4::*;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
use std::io::Write;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The folder to use as the hub cache.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8192)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.2)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.6),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now();
let config = Config::glm4();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,20 @@
# candle-granite LLMs from IBM Research
[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.
## Running the example
```bash
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
--prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"
Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.
In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
```
## Supported Models
There are two different modalities for the Granite family models: Language and Code.
### Granite for language
1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)

View File

@ -0,0 +1,251 @@
// An implementation of different Granite models https://www.ibm.com/granite
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};
use std::time::Instant;
const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
Granite7bInstruct,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long, default_value = "granite7b-instruct")]
model_type: GraniteModel,
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (granite, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let filenames = match args.model_type {
GraniteModel::Granite7bInstruct => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(
Granite::load(vb, &config)?,
tokenizer_filename,
cache,
config,
)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::GraniteEosToks::Single)
});
let default_prompt = match args.model_type {
GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
};
let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("Starting the inference loop:");
print!("{prompt}");
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
let use_cache_kv = cache.use_kv_cache;
(0..args.sample_len)
.inspect(|index| {
if *index == 1 {
start_gen = Instant::now();
}
})
.try_for_each(|index| -> Result<()> {
let (context_size, context_index) = if use_cache_kv && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = granite
.forward(&input, context_index, &mut cache)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
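// Returning an error here is just a way to break out of the try_for_each once an EOS
// token shows up; the error itself is discarded by the unwrap_or(()) below.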
if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
if next_token == eos_tok_id {
return Err(E::msg("EOS token found"));
}
} else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
if eos_ids.contains(&next_token) {
return Err(E::msg("EOS token found"));
}
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
Ok(())
})
.unwrap_or(());
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -14,6 +14,7 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama::LlamaEosToks;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -219,9 +220,16 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
if Some(next_token) == config.eos_token_id {
break;
match config.eos_token_id {
Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
_ => (),
}
if rank == 0 {
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");

View File

@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
if add_prefix_space:
prepend_scheme = "always"
if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
prepend_scheme = "first"
else:
prepend_scheme = "never"
return prepend_scheme
class SentencePieceExtractor:
"""
@ -519,13 +527,15 @@ class SpmConverter(Converter):
)
def pre_tokenizer(self, replacement, add_prefix_space):
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
def post_processor(self):
return None
def decoder(self, replacement, add_prefix_space):
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
@ -636,7 +646,8 @@ class DebertaV2Converter(SpmConverter):
list_pretokenizers = []
if self.original_tokenizer.split_by_punct:
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
return pre_tokenizers.Sequence(list_pretokenizers)
def normalizer(self, proto):
@ -929,10 +940,11 @@ class PegasusConverter(SpmConverter):
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Sequence(
[
pre_tokenizers.WhitespaceSplit(),
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
]
)

View File

@ -0,0 +1,20 @@
# candle-mimi
[Mimi](https://huggingface.co/kyutai/mimi) is a state-of-the-art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming, meaning that it's
possible to encode or decode a stream of audio tokens on the fly to provide
low-latency interaction with an audio model.
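The `--streaming <chunk-size>` flag of the example below exercises this mode: the pcm data is split into fixed-size chunks and each chunk goes through the encoder one at a time (and symmetrically through `decode_step` on the way back). A rough sketch of the encoding side, assuming a `mimi::Model` and a pcm buffer already resampled to 24kHz:
```rust
use candle::{Device, Tensor, D};
use candle_transformers::models::mimi;

// Stream a pcm buffer through the mimi encoder chunk by chunk; the per-chunk
// codes are concatenated along the time dimension at the end.
fn encode_streaming(
    model: &mut mimi::Model,
    pcm: &[f32],
    chunk_size: usize,
    device: &Device,
) -> candle::Result<Tensor> {
    let mut code_chunks = vec![];
    for chunk in pcm.chunks(chunk_size) {
        // The encoder expects a (batch=1, channels=1, samples) layout.
        let chunk = Tensor::new(chunk, device)?.reshape((1, 1, ()))?;
        code_chunks.push(model.encode(&chunk)?);
    }
    Tensor::cat(&code_chunks, D::Minus1)
}
```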
## Running one example
Generating some audio tokens from an audio file.
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```
And decoding the audio tokens back into a sound file.
```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```

View File

@ -0,0 +1,275 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
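// Accumulate the incoming samples in the fixed-size input buffer; each time it fills up,
// run the resampler and queue the resampled output.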
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -0,0 +1,165 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some mimi tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// Whether to use streaming or not, when streaming slices of data of the given size are passed
/// to the encoder/decoder one at a time.
#[arg(long)]
streaming: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::v0_1(None);
let mut model = Model::new(config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
match args.streaming {
Some(chunk_size) => {
let mut code_chunks = vec![];
for pcm in pcm.chunks(chunk_size) {
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
let code_chunk = model.encode(&pcm)?;
code_chunks.push(code_chunk)
}
Tensor::cat(&code_chunks, candle::D::Minus1)?
}
None => {
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
}
}
};
println!("codes shape: {:?}", codes.shape());
model.reset_state();
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = match args.streaming {
Some(chunk_size) => {
let seq_len = codes.dim(candle::D::Minus1)?;
let mut pcm_chunks = vec![];
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
None => model.decode(&codes)?,
};
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())
}

View File

@ -0,0 +1,28 @@
# candle-mobileclip
MobileCLIP is a family of efficient CLIP-like models using FastViT-based image encoders.
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
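The scoring step itself is tiny: the two encoders produce a logit matrix, and softmaxing the per-image logits over the caption axis gives one probability per caption for each image. A minimal sketch, assuming `model`, `images` and `input_ids` are built as in the example below:
```rust
use candle::Tensor;
use candle_nn::ops::softmax;
use candle_transformers::models::mobileclip::MobileClipModel;

// Score a batch of captions against a batch of images. `images` is the stacked
// (n_images, 3, h, w) tensor, `input_ids` the padded caption token ids.
fn caption_scores(
    model: &MobileClipModel,
    images: &Tensor,
    input_ids: &Tensor,
) -> anyhow::Result<Vec<Vec<f32>>> {
    let (_logits_per_text, logits_per_image) = model.forward(images, input_ids)?;
    // Each row sums to one: the probabilities of the captions for a given image.
    Ok(softmax(&logits_per_image, 1)?.to_vec2::<f32>()?)
}
```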
## Running an example on CPU
```
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0025% Text: a cycling race
Probability: 0.0004% Text: a photo of two cats
Probability: 99.9971% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 99.9974% Text: a cycling race
Probability: 0.0024% Text: a photo of two cats
Probability: 0.0002% Text: a robot holding a candle
```

View File

@ -0,0 +1,192 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::mobileclip;
use tokenizers::Tokenizer;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S1,
S2,
}
impl Which {
fn model_name(&self) -> String {
let name = match self {
Self::S1 => "S1",
Self::S2 => "S2",
};
format!("apple/MobileCLIP-{}-OpenCLIP", name)
}
fn config(&self) -> mobileclip::MobileClipConfig {
match self {
Self::S1 => mobileclip::MobileClipConfig::s1(),
Self::S2 => mobileclip::MobileClipConfig::s2(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
#[arg(value_enum, long, default_value_t=Which::S1)]
which: Which,
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
image_size,
&[0.0, 0.0, 0.0],
&[1.0, 1.0, 1.0],
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
}
}
Ok(())
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
// let pad_id = *tokenizer
// .get_vocab(true)
// .get("<|endoftext|>")
// .ok_or(E::msg("No pad token"))?;
// The model does not work well if the text is padded using the <|endoftext|> token, so use 0
// as in the original OpenCLIP code.
let pad_id = 0;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -72,8 +72,9 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
.to_device(&device)?;
let image =
candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {

View File

@ -284,11 +284,11 @@ impl MusicgenDecoder {
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
.map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}"))))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
.map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,

View File

@ -0,0 +1,23 @@
# candle-parler-tts
[Parler-TTS](https://huggingface.co/parler-tts/parler-tts-large-v1) is a large
text-to-speech model with 2.2B parameters trained on ~45K hours of audio data.
The voice can be controlled by a text prompt.
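Internally, the example below tokenizes both texts, runs the autoregressive generation conditioned on the description tokens, and decodes the resulting audio codes back to pcm with the model's audio encoder. A condensed sketch, assuming the token tensors are built as in the example and greedy sampling is used:
```rust
use candle::{DType, IndexOp, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::parler_tts::Model;

// The prompt is what gets spoken, the description conditions the voice.
// `prompt_tokens` and `description_tokens` are (1, seq_len) token-id tensors.
fn speak(
    model: &mut Model,
    prompt_tokens: &Tensor,
    description_tokens: &Tensor,
    sampling_rate: u32,
    out_file: &str,
) -> anyhow::Result<()> {
    let lp = LogitsProcessor::new(0, Some(0.0), None);
    let codes = model.generate(prompt_tokens, description_tokens, lp, 512)?;
    let codes = codes.to_dtype(DType::I64)?.unsqueeze(0)?;
    // Decode the audio codes to pcm, normalize the loudness and write a wav file.
    let pcm = model.audio_encoder.decode_codes(&codes)?.i((0, 0))?;
    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
    let pcm = pcm.to_vec1::<f32>()?;
    let mut output = std::fs::File::create(out_file)?;
    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, sampling_rate)?;
    Ok(())
}
```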
## Run an example
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?"
```
To describe the voice to be used, pass a prompt via the `--description` argument.
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?" \
--description "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
```
https://github.com/user-attachments/assets/1b16aeac-70a3-4803-8589-4563279bba33

Binary file not shown.

View File

@ -0,0 +1,206 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
#[arg(
long,
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
)]
description: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.0)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value_t = 512)]
max_steps: usize,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "large-v1")]
which: Which,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "large-v1")]
LargeV1,
#[value(name = "mini-v1")]
MiniV1,
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(),
Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::MiniV1 => vec![repo.get("model.safetensors")?],
Which::LargeV1 => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
let mut model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let description_tokens = tokenizer
.encode(args.description, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
let prompt_tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
let lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
args.top_p,
);
println!("starting generation...");
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
println!("generated codes\n{codes}");
let codes = codes.to_dtype(DType::I64)?;
codes.save_safetensors("codes", "out.safetensors")?;
let codes = codes.unsqueeze(0)?;
let pcm = model
.audio_encoder
.decode_codes(&codes.to_device(&device)?)?;
println!("{pcm}");
let pcm = pcm.i((0, 0))?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
Ok(())
}

View File

@ -0,0 +1,12 @@
# silero-vad: Voice Activity Detection
[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.
This example uses the ONNX model available in the Hugging Face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad) repository.
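Under the hood, the example feeds fixed-size pcm frames to the ONNX graph and carries two things across calls: the recurrent `state` tensor and a small `context` made of the last samples of the previous frame. A sketch of a single step, assuming the model was loaded with `candle_onnx::read_file` and a 16kHz stream (512-sample frames, 64-sample context):
```rust
use candle::Tensor;
use std::collections::HashMap;

// One VAD step: returns the speech probability for the frame and the updated
// recurrent state. `chunk` is the 64-sample context followed by the new
// 512-sample frame, i.e. a (1, 576) tensor.
fn vad_step(
    model: &candle_onnx::onnx::ModelProto,
    chunk: Tensor,
    sample_rate: &Tensor, // scalar i64 tensor holding 16000
    state: &Tensor,       // (2, 1, 128)
) -> anyhow::Result<(f32, Tensor)> {
    let inputs = HashMap::from_iter([
        ("input".to_string(), chunk),
        ("sr".to_string(), sample_rate.clone()),
        ("state".to_string(), state.clone()),
    ]);
    let out = candle_onnx::simple_eval(model, inputs)?;
    let names = &model.graph.as_ref().unwrap().output;
    let prob = out[&names[0].name].flatten_all()?.to_vec1::<f32>()?[0];
    let new_state = out[&names[1].name].clone();
    Ok((prob, new_state))
}
```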
## Running the example
```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```

View File

@ -0,0 +1,199 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use candle::{DType, Tensor};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "silero")]
Silero,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum SampleRate {
#[value(name = "8000")]
Sr8k,
#[value(name = "16000")]
Sr16k,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
input: Option<String>,
#[arg(long)]
sample_rate: SampleRate,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The model to use.
#[arg(long, default_value = "silero")]
which: Which,
}
/// An iterator which reads consecutive frames of little-endian i16 values from a reader.
struct I16Frames<R> {
rdr: R,
buf: Box<[u8]>,
len: usize,
eof: bool,
}
impl<R> I16Frames<R> {
fn new(rdr: R, frame_size: usize) -> Self {
I16Frames {
rdr,
buf: vec![0; frame_size * std::mem::size_of::<i16>()].into_boxed_slice(),
len: 0,
eof: false,
}
}
}
impl<R: std::io::Read> Iterator for I16Frames<R> {
type Item = std::io::Result<Vec<f32>>;
fn next(&mut self) -> Option<Self::Item> {
if self.eof {
return None;
}
self.len += match self.rdr.read(&mut self.buf[self.len..]) {
Ok(0) => {
self.eof = true;
0
}
Ok(n) => n,
Err(e) => return Some(Err(e)),
};
if self.eof || self.len == self.buf.len() {
let buf = self.buf[..self.len]
.chunks(2)
.map(|bs| match bs {
[a, b] => i16::from_le_bytes([*a, *b]),
_ => unreachable!(),
})
.map(|i| i as f32 / i16::MAX as f32)
.collect();
self.len = 0;
Some(Ok(buf))
} else {
self.next()
}
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let start = std::time::Instant::now();
let model_id = match &args.model_id {
Some(model_id) => std::path::PathBuf::from(model_id),
None => match args.which {
Which::Silero => hf_hub::api::sync::Api::new()?
.model("onnx-community/silero-vad".into())
.get("onnx/model.onnx")?,
// TODO: candle-onnx doesn't support Int8 dtype
// Which::SileroQuantized => hf_hub::api::sync::Api::new()?
// .model("onnx-community/silero-vad".into())
// .get("onnx/model_quantized.onnx")?,
},
};
let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate {
SampleRate::Sr8k => (8000, 256, 32),
SampleRate::Sr16k => (16000, 512, 64),
};
println!("retrieved the files in {:?}", start.elapsed());
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let model = candle_onnx::read_file(model_id)?;
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
struct State {
frame_size: usize,
sample_rate: Tensor,
state: Tensor,
context: Tensor,
}
let mut state = State {
frame_size,
sample_rate: Tensor::new(sample_rate, &device)?,
state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
context: Tensor::zeros((1, context_size), DType::F32, &device)?,
};
let mut res = vec![];
for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) {
let chunk = chunk.unwrap();
if chunk.len() < state.frame_size {
continue;
}
let next_context = Tensor::from_slice(
&chunk[state.frame_size - context_size..],
(1, context_size),
&device,
)?;
let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?;
let chunk = Tensor::cat(&[&state.context, &chunk], 1)?;
let inputs = std::collections::HashMap::from_iter([
("input".to_string(), chunk),
("sr".to_string(), state.sample_rate.clone()),
("state".to_string(), state.state.clone()),
]);
let out = candle_onnx::simple_eval(&model, inputs).unwrap();
let out_names = &model.graph.as_ref().unwrap().output;
let output = out.get(&out_names[0].name).unwrap().clone();
state.state = out.get(&out_names[1].name).unwrap().clone();
assert_eq!(state.state.dims(), &[2, 1, 128]);
state.context = next_context;
let output = output.flatten_all()?.to_vec1::<f32>()?;
assert_eq!(output.len(), 1);
let output = output[0];
println!("vad chunk prediction: {output}");
res.push(output);
}
println!("calculated prediction in {:?}", start.elapsed());
let res_len = res.len() as f32;
let prediction = res.iter().sum::<f32>() / res_len;
println!("vad average prediction: {prediction}");
Ok(())
}
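The `I16Frames` iterator above turns raw little-endian bytes from stdin into normalized `f32` samples. A minimal standalone sketch of the same conversion, with a hypothetical `bytes_to_f32` helper and a made-up two-sample input:

```rust
// Minimal sketch of the byte -> f32 conversion used by I16Frames above.
// Assumes `bytes` contains complete little-endian i16 samples.
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|bs| i16::from_le_bytes([bs[0], bs[1]]))
        .map(|i| i as f32 / i16::MAX as f32)
        .collect()
}

fn main() {
    // 0x7FFF is i16::MAX, so the second sample normalizes to exactly 1.0.
    let bytes = [0x00u8, 0x00, 0xFF, 0x7F];
    assert_eq!(bytes_to_f32(&bytes), vec![0.0, 1.0]);
}
```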

View File

@ -123,7 +123,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
let (bn, bias) = match b.parameters.get("batch_normalize") {
Some(p) if p.parse::<usize>()? != 0 => {
let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
let bn = batch_norm(filters, 1e-5, vb.pp(format!("batch_norm_{index}")))?;
(Some(bn), false)
}
Some(_) | None => (None, true),
@ -135,9 +135,9 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
dilation: 1,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
} else {
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
};
let leaky = match activation {
"leaky" => true,

View File

@ -161,7 +161,7 @@ impl C2f {
let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?;
let mut bottleneck = Vec::with_capacity(n);
for idx in 0..n {
let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?;
let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?;
bottleneck.push(b)
}
Ok(Self {

View File

@ -1,23 +1,42 @@
use candle::{Device, Result, Tensor};
/// Loads an image from disk using the image crate at the requested resolution.
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
/// Loads an image from disk using the image crate at the requested resolution,
/// using the given std and mean parameters.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
p: P,
res: usize,
mean: &[f32; 3],
std: &[f32; 3],
) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
.resize_to_fill(
res as u32,
res as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
/// Loads an image from disk using the image crate at the requested resolution.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 224, 224). imagenet normalization is applied.
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.6.0"
version = "0.7.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.6.0"
version = "0.7.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.6.0"
version = "0.7.1"
edition = "2021"
description = "Metal kernels for Candle"
@ -17,9 +17,12 @@ thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
clap = { version = "4.2.4", features = ["derive"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
anyhow = "1"
rand = "0.8.5"
rand_distr = "0.4.3"

View File

@ -0,0 +1,136 @@
use anyhow::Result;
use candle_metal_kernels::GemmDType;
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
use clap::{Parser, Subcommand};
use half::f16;
fn run_gemm(f32: bool, n: usize) -> Result<()> {
const WARMUP_ITERS: usize = 2;
const MIN_DUR: f64 = 4.;
let device = metal::Device::system_default().unwrap();
let (b, m, n, k) = (1, n, n, n);
let kernels = candle_metal_kernels::Kernels::new();
let command_queue = device.new_command_queue();
let options = metal::MTLResourceOptions::StorageModeManaged;
let (lhs, rhs) = if f32 {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
} else {
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
};
let (dtype, name, sizeof) = if f32 {
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
} else {
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
};
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
for mlx in [false, true] {
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
if mlx {
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
} else {
candle_metal_kernels::call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
let mlx = if mlx { "MLX" } else { "MFA" };
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
}
Ok(())
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Gemm,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
}
fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Gemm => {
for f32 in [false, true] {
for n in [512, 1024, 2048, 4096] {
run_gemm(f32, n)?;
}
}
}
}
Ok(())
}
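The GFLOPS figure printed by `run_gemm` uses the usual 2·n³ flop count for a square matmul. A small self-contained sketch of that computation with made-up numbers:

```rust
// GFLOPS for an n x n x n matmul: 2 * n^3 floating point ops per iteration.
fn gflops(n: usize, iters: usize, total_secs: f64) -> f64 {
    (2 * n * n * n * iters) as f64 / (1e9 * total_secs)
}

fn main() {
    // Hypothetical numbers: a single 1024^3 matmul taking 10ms is ~215 GFLOPS.
    let g = gflops(1024, 1, 0.010);
    println!("{g:.0}"); // prints 215
}
```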

View File

@ -105,7 +105,7 @@ kernel void FN_NAME##_strided( \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
} \
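The one-character fix above matches the standard ELU definition, with `mul` presumably carrying the α parameter; the previous expression computed α·e^{x−1} instead of α·(e^x − 1):

```latex
\mathrm{ELU}(x) =
\begin{cases}
x & x > 0 \\
\alpha\,(e^{x} - 1) & x \le 0
\end{cases}
```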

View File

@ -11,33 +11,35 @@ pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Indexing,
Unary,
Binary,
Ternary,
Cast,
Reduce,
Mfa,
Conv,
Random,
Gemm,
Indexing,
Mfa,
Quantized,
Random,
Reduce,
Sort,
Ternary,
Unary,
}
pub mod copy2d {
@ -191,16 +193,17 @@ impl Kernels {
fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
Source::Random => RANDOM,
Source::Reduce => REDUCE,
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
BF16,
F16,
F32,
}
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_gemm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct GemmParams {
m: i32,
n: i32,
k: i32,
lda: i32,
ldb: i32,
ldd: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_a: isize,
batch_stride_b: isize,
batch_stride_d: isize,
swizzle_log: i32,
gemm_k_iterations_aligned: i32,
batch_ndim: i32,
}
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, false)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, false)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
let constants = Some(ConstantValues::new(vec![
(10, Value::Bool(/* has_batch */ b > 1)),
(100, Value::Bool(/* use_out_source */ false)),
(110, Value::Bool(/* do_axpby */ false)),
(200, Value::Bool(/* align_m */ m % bm == 0)),
(201, Value::Bool(/* align_n */ n % bn == 0)),
(202, Value::Bool(/* align_k */ k % bk == 0)),
(300, Value::Bool(/* do_gather */ false)),
]));
let swizzle_log = 0;
let tile = 1 << swizzle_log;
let tn = n.div_ceil(bn);
let tm = m.div_ceil(bm);
let tn = tn * tile;
let tm = tm.div_ceil(tile);
let batch_stride_a = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3]
} else {
m * k
};
let batch_stride_b = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3]
} else {
n * k
};
let gemm_params = GemmParams {
m: m as i32,
n: n as i32,
k: k as i32,
lda,
ldb,
ldd: n as i32,
tiles_n: tn as i32,
tiles_m: tm as i32,
swizzle_log,
batch_stride_a: batch_stride_a as isize,
batch_stride_b: batch_stride_b as isize,
batch_stride_d: (m * n) as isize,
batch_ndim: 1i32,
gemm_k_iterations_aligned: (k / bk) as i32,
};
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
// TODO(laurent): generate the name
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
let name = match (dtype, a_trans, b_trans) {
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
};
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(3, Some(output), 0);
encoder.set_bytes(
4,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const c_void,
);
encoder.set_bytes(
6, // batch_shape
std::mem::size_of::<i32>() as u64,
&(b as i32) as *const i32 as *const c_void,
);
encoder.set_bytes(
7,
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
batch_strides.as_ptr() as *const c_void,
);
let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
depth: /* batch_size_out */ b as u64,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
#[cfg(test)]
mod tests;
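The lda/ldb selection in `call_mlx_gemm` reads the two minor strides to decide whether an operand is row-major or a transposed view. A simplified standalone sketch of that check for a single (m, k) matrix, ignoring the size-1 edge cases the real dispatch also accepts:

```rust
// Returns (leading_dimension, transposed) for an m x k matrix described by
// its last two strides, mirroring the contiguous / transposed cases above.
fn detect_layout(m: usize, k: usize, stride_row: usize, stride_col: usize) -> Option<(usize, bool)> {
    if stride_col == 1 && stride_row == k {
        Some((k, false)) // row-major: elements of a row are contiguous
    } else if stride_col == m && stride_row == 1 {
        Some((m, true)) // transposed view: elements of a column are contiguous
    } else {
        None // other layouts are rejected as non-contiguous by the dispatch
    }
}

fn main() {
    assert_eq!(detect_layout(2, 3, 3, 1), Some((3, false)));
    assert_eq!(detect_layout(2, 3, 1, 2), Some((2, true)));
    assert_eq!(detect_layout(2, 3, 4, 1), None);
}
```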

File diff suppressed because it is too large.

View File

@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
#[test]
fn cast_f32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -360,7 +360,7 @@ fn cast_f32() {
#[test]
fn cast_f16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -391,7 +391,7 @@ fn cast_f16() {
#[test]
fn cast_bf16() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -422,7 +422,7 @@ fn cast_bf16() {
#[test]
fn cast_u32() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -453,7 +453,7 @@ fn cast_u32() {
#[test]
fn cast_u8() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -484,7 +484,7 @@ fn cast_u8() {
#[test]
fn cast_i64() {
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -911,7 +911,7 @@ fn softmax() {
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@ -922,7 +922,7 @@ fn softmax() {
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() {
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[allow(clippy::too_many_arguments)]
fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: Vec<usize>,
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>(
&kernels,
name,
(b, m, n, k),
&lhs_stride,
lhs_stride,
lhs_offset,
&lhs,
&rhs_stride,
rhs_stride,
rhs_offset,
&rhs,
&output,
@ -1105,10 +1106,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@ -1125,10 +1126,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@ -1150,10 +1151,10 @@ fn gemm() {
"sgemm",
(1, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
12 * 4,
);
assert_eq!(
@ -1172,10 +1173,10 @@ fn gemm() {
"bgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@ -1194,10 +1195,10 @@ fn gemm() {
"hgemm",
(b, m, n, k),
&lhs,
lhs_stride,
&lhs_stride,
0,
&rhs,
rhs_stride,
&rhs_stride,
0,
);
assert_eq!(
@ -1206,6 +1207,204 @@ fn gemm() {
);
}
#[allow(clippy::too_many_arguments)]
fn run_mlx_gemm<T: Clone>(
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
lhs_stride,
lhs_offset,
&lhs,
rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
use rand::SeedableRng;
use rand_distr::Distribution;
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
let v1: Vec<f32> = run_mlx_gemm(
dtype,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
let v2: Vec<f32> = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
for (a, b) in v1.iter().zip(v2.iter()) {
let diff = (a - b).abs();
assert_eq!((diff * 1e4).round(), 0.)
}
}
#[test]
fn mlx_vs_mfa() {
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
}
#[test]
fn mlx_gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_mlx_gemm(
GemmDType::F32,
(1, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
12 * 4,
);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
// bgemm sanity test
{
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::BF16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
{
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::F16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_f16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
@ -1280,7 +1479,7 @@ fn random() {
variance.sqrt()
}
let shape = vec![1024, 10];
let shape = [1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
@ -1636,7 +1835,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@ -1656,7 +1855,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = vec![5.0, 7.0, 13.0, 15.0]
let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = vec![5.0, 7.0, 13.0, 15.0]
let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() {
&strides,
"avg_pool2d_f16",
);
let expected = vec![
let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() {
&strides,
"avg_pool2d_bf16",
);
let expected = vec![
let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() {
#[test]
fn conv_transpose1d_f16() {
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() {
"conv_transpose1d_f16",
);
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() {
#[test]
fn conv_transpose1d_bf16() {
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() {
"conv_transpose1d_bf16",
);
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();

View File

@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
}
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
@ -154,7 +154,6 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
// tanh may create NaN on large values, e.g. for an input of 45 rather than outputting 1.
// This has been an issue for the encodec example.
UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
COPY2D(copy2d_bf16, bfloat)
#endif
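For reference, the `gelu` helper touched in the first hunk implements the usual tanh approximation (the M_2_SQRTPI·M_SQRT1_2 factor is √(2/π)); the change only routes the tanh through `precise::tanh`:

```latex
\mathrm{GELU}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\, x^{3}\right)\right)\right)
```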

View File

@ -165,20 +165,25 @@ pub trait EncoderProvider {
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
fn encoder(&self) -> Self::Encoder<'_>;
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
pub struct WrappedEncoder<'a> {
inner: &'a ComputeCommandEncoderRef,
end_encoding_on_drop: bool,
}
impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) {
self.0.end_encoding()
if self.end_encoding_on_drop {
self.inner.end_encoding()
}
}
}
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
&self.0
self.inner
}
}
@ -186,8 +191,11 @@ impl EncoderProvider for &metal::CommandBuffer {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
WrappedEncoder(self.new_compute_command_encoder())
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
}
}
@ -195,7 +203,22 @@ impl EncoderProvider for &metal::CommandBufferRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
WrappedEncoder(self.new_compute_command_encoder())
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
}
}
impl EncoderProvider for &ComputeCommandEncoderRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self,
end_encoding_on_drop: false,
}
}
}
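The `end_encoding_on_drop` flag turns `WrappedEncoder` into a drop guard that only finalizes an encoder it created itself. A generic sketch of the pattern, unrelated to Metal:

```rust
// Drop guard that only runs cleanup when it owns the underlying resource,
// mirroring the end_encoding_on_drop flag above.
struct Guard<'a> {
    name: &'a str,
    cleanup_on_drop: bool,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        if self.cleanup_on_drop {
            println!("finalizing {}", self.name);
        }
    }
}

fn main() {
    let _owned = Guard { name: "encoder created here", cleanup_on_drop: true };
    let _borrowed = Guard { name: "caller-provided encoder", cleanup_on_drop: false };
    // Only the first guard prints when both go out of scope.
}
```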

View File

@ -1,4 +1,4 @@
use candle::{Result, Tensor};
use candle::{Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -145,3 +145,225 @@ impl KvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct RotatingCache {
all_data: Option<Tensor>,
dim: usize,
// `offset` is the current write index in the buffer
offset: usize,
// The total size of the sequence seen so far.
current_seq_len: usize,
// max_seq_len is the size of the rotating buffer; the full sequence is actually allowed
// to grow past this limit.
max_seq_len: usize,
}
impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
}
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}
pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
self.current_seq_len += seq_len;
if seq_len >= self.max_seq_len {
let to_copy = src
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
.contiguous()?;
ad.slice_set(&to_copy, self.dim, 0)?;
self.offset = 0;
// Here we return `src` rather than `ad` so that all the past can be used.
Ok(src.clone())
} else {
let rem_len = self.max_seq_len - self.offset;
if seq_len <= rem_len {
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src
.narrow(self.dim, rem_len, seq_len - rem_len)?
.contiguous()?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
if self.current_seq_len >= self.max_seq_len {
Ok(ad.clone())
} else {
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
}
}
}
fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let upd_offset = (self.offset + size1) % self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|pos_src| {
// The absolute position of the elements that will get added to the cache.
let pos_src = self.current_seq_len + pos_src;
(0..size2).map(move |pos_cache_rel| {
// The absolute position of the cache elements after the addition.
let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
let pos_cache = if pos_cache_rel < upd_offset {
pos_cache
} else {
pos_cache - self.max_seq_len
};
u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
let mask = if seq_len == 1 {
None
} else {
let mask = if seq_len < self.max_seq_len {
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
self.get_mask_rel(seq_len, cache_out_len, device)?
} else {
self.get_mask_abs(seq_len, seq_len, device)?
};
Some(mask)
};
Ok(mask)
}
}
#[derive(Debug, Clone)]
pub struct RotatingKvCache {
k: RotatingCache,
v: RotatingCache,
}
impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
let v = RotatingCache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &RotatingCache {
&self.k
}
pub fn v_cache(&self) -> &RotatingCache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
let out_k = self.k.append(k)?;
let out_v = self.v.append(v)?;
Ok((out_k, out_v))
}
pub fn offset(&self) -> usize {
self.k.offset()
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
self.k.attn_mask(seq_len, device)
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}
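A minimal sketch of the wrap-around bookkeeping performed by `RotatingCache::append`, using a plain `Vec<f32>` instead of a `Tensor` and only covering inputs shorter than the buffer (the real code also handles oversized inputs and mask generation):

```rust
// Simplified ring buffer mirroring the offset arithmetic of RotatingCache::append.
struct Ring {
    data: Vec<f32>,
    offset: usize,
    seq_len: usize,
}

impl Ring {
    fn new(max_seq_len: usize) -> Self {
        Self { data: vec![0.0; max_seq_len], offset: 0, seq_len: 0 }
    }

    fn append(&mut self, src: &[f32]) {
        let max = self.data.len();
        assert!(src.len() < max);
        self.seq_len += src.len();
        let rem = max - self.offset;
        if src.len() <= rem {
            self.data[self.offset..self.offset + src.len()].copy_from_slice(src);
            self.offset = (self.offset + src.len()) % max;
        } else {
            // Two copies when the write crosses the end of the buffer.
            self.data[self.offset..].copy_from_slice(&src[..rem]);
            self.data[..src.len() - rem].copy_from_slice(&src[rem..]);
            self.offset = src.len() - rem;
        }
    }
}

fn main() {
    // Same sequence of appends as the rotating_kv_cache test below.
    let mut r = Ring::new(6);
    r.append(&[1., 2., 3.]);
    r.append(&[4.]);
    r.append(&[0., 5., 6., 7.]); // wraps: the last two values land at the start
    assert_eq!(r.data, vec![6., 7., 3., 4., 0., 5.]);
    assert_eq!(r.offset, 2);
    assert_eq!(r.seq_len, 8);
}
```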

View File

@ -55,6 +55,10 @@ pub struct LSTMState {
}
impl LSTMState {
pub fn new(h: Tensor, c: Tensor) -> Self {
LSTMState { h, c }
}
/// The hidden state vector, which is also the output of the LSTM.
pub fn h(&self) -> &Tensor {
&self.h

candle-nn/tests/kv_cache.rs (new file, 110 lines)
View File

@ -0,0 +1,110 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{Device, Result, Tensor};
#[test]
fn kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
for _ in [0, 1] {
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?;
assert!(data.is_none());
let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]);
let t = Tensor::new(&[4f32], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]);
let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?;
cache.append(&t)?;
let data = cache.current_data()?.unwrap();
assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
assert_eq!(cache.current_seq_len(), 8);
cache.reset();
}
Ok(())
}
#[test]
fn rotating_kv_cache() -> Result<()> {
let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6);
for _ in [0, 1] {
assert_eq!(cache.offset(), 0);
assert_eq!(cache.current_seq_len(), 0);
let data = cache.current_data()?;
assert!(data.is_none());
let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
let t = Tensor::new(&[4.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]);
assert_eq!(cache.current_seq_len(), 8);
assert_eq!(cache.offset(), 2);
let t = Tensor::new(&[8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]);
assert_eq!(cache.current_seq_len(), 9);
assert_eq!(cache.offset(), 3);
let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]);
assert_eq!(cache.current_seq_len(), 12);
assert_eq!(cache.offset(), 0);
let t = Tensor::new(&[12.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]);
assert_eq!(cache.current_seq_len(), 13);
assert_eq!(cache.offset(), 1);
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
);
let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 22);
assert_eq!(cache.offset(), 0);
let mask = cache.attn_mask(1, &Device::Cpu)?;
assert!(mask.is_none());
let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
assert_eq!(
mask.to_vec2::<u8>()?,
&[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
);
let t = Tensor::new(&[42.], &Device::Cpu)?;
let data = cache.append(&t)?;
assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
assert_eq!(cache.current_seq_len(), 23);
assert_eq!(cache.offset(), 1);
cache.reset();
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.6.0"
version = "0.7.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
candle-nn = { path = "../candle-nn", version = "0.6.0" }
candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
candle-nn = { path = "../candle-nn", version = "0.7.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -66,6 +66,18 @@ impl Attr for GraphProto {
}
}
impl AttrOwned for Vec<String> {
const TYPE: AttributeType = AttributeType::Strings;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
let mut ret = vec![];
for bytes in attr.strings.iter() {
let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
ret.push(s);
}
Ok(ret)
}
}
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@ -340,8 +352,15 @@ fn simple_eval_(
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
}
"Exp" => {
let xs = get(&node.input[0])?;
@ -610,6 +629,18 @@ fn simple_eval_(
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done.
@ -1310,6 +1341,233 @@ fn simple_eval_(
.broadcast_add(&c.broadcast_mul(&beta)?)?;
values.insert(node.output[0].clone(), output);
}
"LSTM" => {
let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
if direction != "forward" {
bail!("LSTM currently only supports direction == \"forward\"");
}
let num_directions = if direction == "bidirectional" { 2 } else { 1 };
let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
if input_forget != 0 {
bail!("LSTM currently only supports input_forget == 0");
}
let activations_default = vec![
"Sigmoid".to_string(),
"Tanh".to_string(),
"Tanh".to_string(),
];
let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
.unwrap_or(activations_default.clone());
if activations != activations_default {
bail!("LSTM currently only supports default activations ({activations_default:?})");
}
// activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
if get_attr_opt::<f32>(node, "clip")?.is_some() {
bail!("LSTM does not currently support clip attribute");
}
// The shape format of inputs X, initial_h and outputs Y, Y_h.
// If 0, the following shapes are expected:
// X.shape = [seq_length, batch_size, input_size],
// Y.shape = [seq_length, num_directions, batch_size, hidden_size],
// initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
// If 1, the following shapes are expected:
// X.shape = [batch_size, seq_length, input_size],
// Y.shape = [batch_size, seq_length, num_directions, hidden_size],
// initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
if layout != 0 {
bail!("LSTM currently only supports layout == 0");
}
// The input sequences packed (and potentially padded) into one 3-D tensor
// with the shape of `[seq_length, batch_size, input_size]`.
let x = get(&node.input[0])?;
// XXX: depends on layout
let (seq_length, batch_size, input_size) = x.dims3()?;
// The weight tensor for the gates.
// Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
// The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
let w = get(&node.input[1])?;
// The recurrence weight tensor.
// Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
// Optional: If not specified - assumed to be 0.
let b_default: Tensor;
let b = match get_opt(3) {
Some(n) => n?,
None => {
b_default = Tensor::zeros(
(num_directions, 8 * hidden_size as usize),
DType::F32,
x.device(),
)?;
&b_default
}
};
// Optional tensor specifying lengths of the sequences in a batch.
// If not specified - assumed all sequences in the batch to have length `seq_length`.
// It has shape `[batch_size]`.
let seq_lens_default: Tensor;
let seq_lens = match get_opt(4) {
Some(n) => n?,
None => {
seq_lens_default =
Tensor::full(seq_length as i64, (batch_size,), x.device())?;
&seq_lens_default
}
};
let seq_lens_is_default =
(seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
if !seq_lens_is_default {
bail!("LSTM currently only supports default value of seq_lens");
}
// Optional initial value of the hidden. If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_h_default: Tensor;
let initial_h = match get_opt(5) {
Some(n) => n?,
_ => {
initial_h_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_h_default
}
};
// Optional initial value of the cell.
// If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_c_default: Tensor;
let initial_c = match node.input.get(6) {
Some(n) if !n.is_empty() => get(n)?,
_ => {
initial_c_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_c_default
}
};
// The weight tensor for peepholes.
// Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
// It has shape `[num_directions, 3*hidden_size]`. Optional: If not specified - assumed to be 0.
let p_default = Tensor::zeros(
(num_directions, 3 * hidden_size as usize),
DType::F32,
x.device(),
)?;
let p = get_opt(7).unwrap_or(Ok(&p_default))?;
let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
if !p_is_zeros {
bail!(
"LSTM currently only supports default value of p (a Tensor of all zeroes)"
);
}
// these all have [num_directions, ...] shapes
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
let c = initial_c.get(0)?;
let h = initial_h.get(0)?;
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
let w = w.index_select(&idx_ifco, 0)?;
let r = r.index_select(&idx_ifco, 0)?;
let wb = wb.index_select(&idx_ifco, 0)?;
let rb = rb.index_select(&idx_ifco, 0)?;
let vmap = candle_nn::VarMap::new();
vmap.data().lock().unwrap().extend([
("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
]);
use candle_nn::rnn::RNN as _;
let lstm = candle_nn::rnn::lstm(
input_size,
hidden_size as usize,
candle_nn::rnn::LSTMConfig::default(),
candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
};
for t in 0..seq_length {
let x = x.get(t)?;
lstm_state = lstm.step(&x, &lstm_state)?;
if let Some(h_acc) = &mut h_acc {
h_acc.push(lstm_state.clone());
}
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
seq_length,
num_directions,
batch_size,
hidden_size as usize,
))?;
values.insert(name.clone(), h_acc);
}
if let Some(name) = node.output.get(1) {
values.insert(
name.clone(),
lstm_state.h().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
if let Some(name) = node.output.get(2) {
values.insert(
name.clone(),
lstm_state.c().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
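The index gymnastics in the LSTM branch rearrange the ONNX iofc gate order into the ifco order expected by `candle_nn::rnn::lstm`. A tiny standalone sketch of the same reordering, assuming a flat slice of 4·hidden_size gate rows:

```rust
// Reorder per-gate blocks from ONNX's (i, o, f, c) layout to (i, f, c, o),
// mirroring the idx_ifco index_select above, on a plain slice.
fn iofc_to_ifco<T: Copy>(rows: &[T], hidden_size: usize) -> Vec<T> {
    assert_eq!(rows.len(), 4 * hidden_size);
    let h = hidden_size;
    // Gate blocks in the ONNX layout: i = 0, o = 1, f = 2, c = 3.
    let (i, o, f, c) = (&rows[0..h], &rows[h..2 * h], &rows[2 * h..3 * h], &rows[3 * h..4 * h]);
    [i, f, c, o].concat()
}

fn main() {
    // hidden_size = 2; entries are labelled by their gate for readability.
    let iofc = ["i0", "i1", "o0", "o1", "f0", "f1", "c0", "c1"];
    assert_eq!(
        iofc_to_ifco(&iofc, 2),
        ["i0", "i1", "f0", "f1", "c0", "c1", "o0", "o1"]
    );
}
```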

View File

@ -6,11 +6,13 @@ extern crate accelerate_src;
use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::eval::Value;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use candle_onnx::simple_eval;
use std::collections::HashMap;
const INPUT_X: &str = "x";
@ -3514,3 +3516,467 @@ fn test_slice() -> Result<()> {
Ok(())
}
#[test]
fn test_lstm() -> Result<()> {
// values generated from pytorch, so at least it's close enough to what pytorch does
/*
#!/usr/bin/env python3
# torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)
import torch
rand_gen = torch.Generator()
rand_gen.manual_seed(1)
input_size = 3
hidden_size = 5
batch_size = 1
sequence_length = 4
number_directions = 1
rnn = torch.nn.LSTM(input_size,hidden_size)
weight_ih_l0 = torch.randn(rnn.weight_ih_l0.shape, generator=rand_gen)
weight_hh_l0 = torch.randn(rnn.weight_hh_l0.shape, generator=rand_gen)
bias_ih_l0 = torch.randn(rnn.bias_ih_l0.shape, generator=rand_gen)
bias_hh_l0 = torch.randn(rnn.bias_hh_l0.shape, generator=rand_gen)
rnn.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
rnn.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
rnn.bias_ih_l0 = torch.nn.Parameter(bias_ih_l0)
rnn.bias_hh_l0 = torch.nn.Parameter(bias_hh_l0)
input = torch.randn(sequence_length, batch_size, input_size, generator=rand_gen)
h0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
c0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
output, (hn, cn) = rnn(input, (h0, c0))
def fmt_tensor(t):
return "Tensor::from_vec::<_, f32>(vec!"+ str(t.flatten().tolist()) + ", (" + "".join([str(n)+"," for n in t.shape])+"), &Device::Cpu)?"
print("let input_size = ", input_size, ";")
print("let hidden_size = ", hidden_size, ";")
print("let batch_size = ", batch_size, ";")
print("let sequence_length = ", sequence_length, ";")
print("let number_directions = ", number_directions, ";")
print("let weight_ih_l0 = ", fmt_tensor(rnn.weight_ih_l0), ";")
print("let weight_hh_l0 = ", fmt_tensor(rnn.weight_hh_l0), ";")
print("let bias_ih_l0 = ", fmt_tensor(rnn.bias_ih_l0), ";")
print("let bias_hh_l0 = ", fmt_tensor(rnn.bias_hh_l0), ";")
print("let input = ", fmt_tensor(input), ";")
print("let h0 = ", fmt_tensor(h0), ";")
print("let c0 = ", fmt_tensor(c0), ";")
print("let output = ", fmt_tensor(output), ";")
print("let hn = ", fmt_tensor(hn), ";")
print("let cn = ", fmt_tensor(cn), ";")
*/
let input_size = 3;
let hidden_size = 5;
let batch_size = 1;
let sequence_length = 4;
let number_directions = 1;
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-1.5255959033966064,
-0.7502318024635315,
-0.6539809107780457,
-1.6094847917556763,
-0.1001671776175499,
-0.6091889142990112,
-0.9797722697257996,
-1.6090962886810303,
-0.7121446132659912,
0.30372199416160583,
-0.777314305305481,
-0.25145524740219116,
-0.22227048873901367,
1.6871134042739868,
0.22842517495155334,
0.46763551235198975,
-0.6969724297523499,
-1.1607614755630493,
0.6995424032211304,
0.1990816295146942,
0.8656923770904541,
0.2444038987159729,
-0.6629113554954529,
0.8073082566261292,
1.1016806364059448,
-0.1759360432624817,
-2.2455577850341797,
-1.4464579820632935,
0.0611552819609642,
-0.6177444458007812,
-0.7980698347091675,
-0.13162320852279663,
1.8793457746505737,
-0.07213178277015686,
0.15777060389518738,
-0.7734549045562744,
0.1990565061569214,
0.04570277780294418,
0.15295691788196564,
-0.47567880153656006,
-0.11101982742547989,
0.2927352488040924,
-0.1578451544046402,
-0.028787139803171158,
0.4532545804977417,
1.1421611309051514,
0.2486107051372528,
-1.7754007577896118,
-0.025502461940050125,
-1.023330569267273,
-0.5961851477622986,
-1.0055307149887085,
0.42854228615760803,
1.4760777950286865,
-1.7868678569793701,
1.610317587852478,
-0.703956663608551,
-0.18526579439640045,
-0.9962350726127625,
-0.8312552571296692,
],
(20, 3),
&Device::Cpu,
)?;
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
0.4099724292755127,
0.4084506630897522,
0.25786539912223816,
1.095021367073059,
-0.5064865946769714,
0.09977540373802185,
-0.653973400592804,
0.731693685054779,
-1.456732988357544,
1.6089353561401367,
0.09376997500658035,
-1.2597490549087524,
0.25463348627090454,
-0.5019572973251343,
-1.041200041770935,
0.7322672009468079,
1.3075355291366577,
-1.1627987623214722,
0.11963611096143723,
-0.1631353348493576,
0.6614453196525574,
1.1899205446243286,
0.8165339231491089,
-0.9135236144065857,
-0.3538065254688263,
0.7639270424842834,
-0.5889506936073303,
-0.7635973691940308,
1.3352056741714478,
0.6042736172676086,
-0.10344208031892776,
-0.15121692419052124,
1.2465683221817017,
0.505721390247345,
0.9505112171173096,
1.2966482639312744,
0.873796284198761,
-0.5602594017982483,
1.2857844829559326,
0.8168238401412964,
-1.464799404144287,
-1.2629283666610718,
1.122018814086914,
1.5663341283798218,
2.558138370513916,
-0.23336388170719147,
-0.013472129590809345,
1.8606348037719727,
1.549620509147644,
0.34762924909591675,
0.09300802648067474,
0.6147403120994568,
0.7123645544052124,
-1.7765072584152222,
0.3538645803928375,
1.1996132135391235,
-0.7122589349746704,
-0.620034396648407,
-0.22813494503498077,
-0.7892746329307556,
-1.6111117601394653,
-1.8716129064559937,
0.5430836081504822,
0.6606786251068115,
0.270527720451355,
0.5596919655799866,
-0.31839630007743835,
1.5117206573486328,
-1.363267183303833,
-0.9832196235656738,
1.5112667083740234,
0.6418707370758057,
-0.7474458813667297,
-0.923438549041748,
0.5733984112739563,
-0.10929951071739197,
0.5181121230125427,
0.10653535276651382,
0.26924076676368713,
1.3247679471969604,
0.037456899881362915,
-0.6378393173217773,
-0.8147554397583008,
-0.6895065307617188,
0.8436542749404907,
1.1657012701034546,
0.5269321799278259,
1.6192532777786255,
-0.963976263999939,
0.14152038097381592,
-0.1636609584093094,
-0.3582225739955902,
1.7222793102264404,
-0.3035756051540375,
0.23887419700622559,
1.3440011739730835,
0.1032256931066513,
1.1003541946411133,
-0.3416801989078522,
0.947338879108429,
],
(20, 5),
&Device::Cpu,
)?;
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-0.568515956401825,
0.8375961780548096,
1.783660650253296,
-0.1954246610403061,
0.235193133354187,
1.9142433404922485,
1.8364111185073853,
1.324532389640808,
-0.07051458209753036,
0.34697940945625305,
-0.653679609298706,
1.5586202144622803,
0.2185661494731903,
-0.5743072628974915,
1.4571250677108765,
1.7709556818008423,
-2.0172998905181885,
0.42350319027900696,
0.5730220079421997,
-1.7962429523468018,
],
(20,),
&Device::Cpu,
)?;
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
1.2470403909683228,
1.2738511562347412,
0.3909492492675781,
0.387210488319397,
0.14440394937992096,
0.7771684527397156,
-2.3381125926971436,
-0.829120397567749,
1.1661391258239746,
1.4786574840545654,
0.26760873198509216,
0.7561198472976685,
-0.5873361229896545,
-2.061920642852783,
0.4304734766483307,
0.3376566171646118,
-0.3437853455543518,
-0.6172260642051697,
1.2529692649841309,
-0.05141742154955864,
],
(20,),
&Device::Cpu,
)?;
let input = Tensor::from_vec::<_, f32>(
vec![
0.6472128033638,
-0.04116716980934143,
-0.17749308049678802,
-0.500039279460907,
0.8672749400138855,
-0.27319222688674927,
-0.4607681334018707,
-0.0990937128663063,
0.47284480929374695,
1.0049484968185425,
-0.2871420383453369,
-1.1618621349334717,
],
(4, 1, 3),
&Device::Cpu,
)?;
let h0 = Tensor::from_vec::<_, f32>(
vec![
0.02758178487420082,
0.5652382373809814,
-0.011487378738820553,
0.6706400513648987,
-0.4929250478744507,
],
(1, 1, 5),
&Device::Cpu,
)?;
let c0 = Tensor::from_vec::<_, f32>(
vec![
1.505028486251831,
-2.32635498046875,
1.6168899536132812,
-0.9026237726211548,
0.17366823554039001,
],
(1, 1, 5),
&Device::Cpu,
)?;
let output = Tensor::from_vec::<_, f32>(
vec![
0.5956016778945923,
-0.01723279245197773,
0.11035571992397308,
-0.49323174357414246,
0.047632161527872086,
0.6358451843261719,
0.040328118950128555,
-0.3788611590862274,
-0.7464339733123779,
0.20080909132957458,
0.5840265154838562,
0.1453288197517395,
-0.7345298528671265,
-0.5214304327964783,
0.21903817355632782,
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
],
(4, 1, 5),
&Device::Cpu,
)?;
let hn = Tensor::from_vec::<_, f32>(
vec![
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
],
(1, 1, 5),
&Device::Cpu,
)?;
let cn = Tensor::from_vec::<_, f32>(
vec![
0.9630558490753174,
1.0033069849014282,
-1.754899024963379,
-1.5967122316360474,
0.8252924680709839,
],
(1, 1, 5),
&Device::Cpu,
)?;
// end of generated values
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LSTM".to_string(),
name: "LSTM_test".to_string(),
attribute: vec![AttributeProto {
name: "hidden_size".to_string(),
r#type: AttributeType::Int.into(),
i: hidden_size as i64,
..AttributeProto::default()
}],
input: vec![
"input".to_string(),
"w".to_string(),
"r".to_string(),
"b".to_string(), // b
"".to_string(), // seq_lens
"h".to_string(),
"c".to_string(),
],
output: vec!["output".to_string(), "hn".to_string(), "cn".to_string()],
..NodeProto::default()
}],
input: ["input", "w", "r", "b", "h", "c"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: ["output", "hn", "cn"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
..GraphProto::default()
}));
// PyTorch stores the LSTM weights and biases in [i, f, g, o] gate order (g being the cell gate)
// whereas ONNX expects [i, o, f, c], so the rows have to be re-arranged accordingly.
let idx_iofc = {
let stride = hidden_size as i64;
let dev = weight_ih_l0.device();
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
Tensor::cat(&[&idx_i, &idx_o, &idx_f, &idx_g], 0)?
};
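// ONNX expects W with shape (num_directions, 4*hidden_size, input_size), R with shape
// (num_directions, 4*hidden_size, hidden_size), and the two biases concatenated into a single
// B with shape (num_directions, 8*hidden_size).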
let w = weight_ih_l0.index_select(&idx_iofc, 0)?;
let w = w.reshape((number_directions, 4 * hidden_size, input_size))?;
let r = weight_hh_l0.index_select(&idx_iofc, 0)?;
let r = r.reshape((number_directions, 4 * hidden_size, hidden_size))?;
let wb = bias_ih_l0.index_select(&idx_iofc, 0)?;
let rb = bias_hh_l0.index_select(&idx_iofc, 0)?;
let b = Tensor::cat(&[wb, rb], 0)?.reshape((number_directions, 8 * hidden_size))?;
let output = output.reshape((sequence_length, number_directions, batch_size, hidden_size))?;
let result = simple_eval(
&model,
HashMap::from_iter([
("input".to_string(), input),
("w".to_string(), w),
("r".to_string(), r),
("b".to_string(), b),
("h".to_string(), h0),
("c".to_string(), c0),
]),
)?;
let actual_output = result.get("output").unwrap();
assert_eq!(output.dims(), actual_output.dims());
let actual_hn = result.get("hn").unwrap();
assert_eq!(hn.dims(), actual_hn.dims());
let actual_cn = result.get("cn").unwrap();
assert_eq!(cn.dims(), actual_cn.dims());
let diff_close_enough = |a: &Tensor, b| -> Result<_> {
let diffs = a.sub(b)?.flatten_all()?.to_vec1::<f32>()?;
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
};
assert!(
diff_close_enough(&output, &actual_output)?,
"output did not match expected\n{actual_output}\n{output}",
);
assert!(
diff_close_enough(&hn, &actual_hn)?,
"hn did not match expected\n{actual_hn}\n{hn}",
);
assert!(
diff_close_enough(&cn, &actual_cn)?,
"cn did not match expected\n{actual_cn}\n{cn}",
);
Ok(())
}


@@ -0,0 +1,589 @@
//! The Based model from the Stanford Hazy Research group.
//!
//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
//! <https://arxiv.org/abs/2402.18668>
//! Original code: <https://github.com/HazyResearch/based>
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
Func, Linear, RmsNorm, VarBuilder,
};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionFeatureMapConfig {
input_dim: usize,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LinearAttentionConfig {
num_heads: usize,
feature_dim: usize,
feature_map: LinearAttentionFeatureMapConfig,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SlidingWindowAttentionConfig {
num_heads: usize,
window_size: usize,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
vocab_size: usize,
#[serde(rename = "n_embd")]
hidden_size: usize,
#[serde(rename = "n_inner")]
intermediate_size: usize,
#[serde(rename = "n_layer")]
num_hidden_layers: usize,
#[serde(rename = "n_head")]
num_attention_heads: usize,
layer_norm_epsilon: f64,
#[serde(default = "default_rope", rename = "rotary_emb_base")]
rope_theta: f64,
alt_mixer_layers: Vec<usize>,
alt_mixer_2_layers: Vec<usize>,
#[serde(rename = "alt_mixer")]
la: LinearAttentionConfig,
#[serde(rename = "alt_mixer_2")]
swa: SlidingWindowAttentionConfig,
}
fn default_rope() -> f64 {
10_000.0
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
fc1: Linear,
fc2: Linear,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
Ok(Self { fc1, fc2 })
}
}
// Swiglu implementation.
// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs
fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, D::Minus1)?;
&xs[1].silu()? * &xs[0]
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.fc1)?;
let xs = swiglu(&xs)?;
let xs = xs.apply(&self.fc2)?;
Ok(xs)
}
}
// A gated convolutional block.
#[derive(Debug, Clone)]
struct BasedConv {
in_proj: Linear,
out_proj: Linear,
conv: Conv1d,
state: Tensor,
}
impl BasedConv {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dim = cfg.hidden_size * 2;
let conv1d_cfg = Conv1dConfig {
groups: dim,
padding: 2,
..Default::default()
};
let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
Ok(Self {
in_proj,
out_proj,
conv,
state,
})
}
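// Single-step (decoding) path: keep a rolling window of the last 3 inputs in `state` and
// evaluate the depthwise causal convolution on that window for the new position only.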
fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
self.state = self.state.roll(-1, D::Minus1)?;
let (_, _, l) = self.state.dims3()?;
self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;
let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
.sum_keepdim(0)?
.sum(D::Minus1)?;
let xs = xs.unsqueeze(1)?;
Ok(xs)
}
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let xs = xs.apply(&self.in_proj)?;
let us = xs.chunk(2, D::Minus1)?;
let (_b, l, _d) = us[0].dims3()?;
let u_conv = if seqlen_offset > 0 {
self.step(&us[0])?
} else {
let k = std::cmp::min(3, l);
self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
self.state = Tensor::cat(&[&self.state, &xs], 2)?;
us[0]
.transpose(1, 2)?
.apply(&self.conv)?
.narrow(D::Minus1, 0, l)?
.transpose(1, 2)?
};
let u_conv = u_conv.silu()?;
let v = u_conv.broadcast_mul(&us[1])?;
let xs = v.apply(&self.out_proj)?;
Ok(xs)
}
}
// Linear attention approximating softmax using second order Taylor polynomials.
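// The feature map phi sends x to [1, x / d^(1/4), vec(x x^T) / (sqrt(2) * sqrt(d))] with d = input_dim,
// so that phi(q) . phi(k) = 1 + q.k / sqrt(d) + (q.k)^2 / (2d), i.e. the second order Taylor
// expansion of exp(q.k / sqrt(d)).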
#[derive(Debug, Clone)]
struct LinearAttention {
proj_q: Linear,
proj_k: Linear,
proj_v: Linear,
out_proj: Linear,
feature_dim: usize,
num_heads: usize,
input_dim: usize,
k_state: Tensor,
kv_state: Tensor,
}
impl LinearAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let input_dim = cfg.la.feature_map.input_dim;
let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
let proj_k = linear_no_bias(
cfg.hidden_size,
cfg.la.num_heads * cfg.la.feature_dim,
vb.pp("proj_k"),
)?;
let proj_q = linear_no_bias(
cfg.hidden_size,
cfg.la.num_heads * cfg.la.feature_dim,
vb.pp("proj_q"),
)?;
let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
let k_state = Tensor::zeros(
(1, cfg.la.num_heads, 1, 1, expanded_size),
vb.dtype(),
vb.device(),
)?;
let kv_state = Tensor::zeros(
(1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
vb.dtype(),
vb.device(),
)?;
Ok(Self {
proj_q,
proj_k,
proj_v,
out_proj,
feature_dim: cfg.la.feature_dim,
num_heads: cfg.la.num_heads,
input_dim,
k_state,
kv_state,
})
}
fn taylor_expansion(&self) -> Result<Func<'static>> {
let r2 = std::f64::consts::SQRT_2;
let rd = (self.input_dim as f64).sqrt();
let rrd = rd.sqrt();
Ok(Func::new(move |xs| {
let dims = xs.dims();
let mut d = dims.to_vec();
if let Some(last) = d.last_mut() {
*last = 1;
};
let x = xs
.unsqueeze(D::Minus1)?
.broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
let x = (x.flatten_from(D::Minus2)? / r2)?;
let o = Tensor::ones(d, xs.dtype(), xs.device())?;
let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;
Ok(x)
}))
}
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let eps = 1e-12;
let feature_map = self.taylor_expansion()?;
let (b, l, d) = xs.dims3()?;
let q = xs.apply(&self.proj_q)?;
let k = xs.apply(&self.proj_k)?;
let v = xs.apply(&self.proj_v)?;
let q = q
.reshape((b, l, self.num_heads, self.feature_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b, l, self.num_heads, self.feature_dim))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b, l, self.num_heads, d / self.num_heads))?
.transpose(1, 2)?
.contiguous()?;
let q = feature_map.forward(&q)?;
let k = feature_map.forward(&k)?;
let y = if seqlen_offset > 0 {
let (_b, _h, l, _d) = k.dims4()?;
let q = q.unsqueeze(D::Minus2)?;
let k = k.unsqueeze(D::Minus2)?;
let v = v.unsqueeze(D::Minus1)?;
let kn = k.narrow(D::Minus1, l - 1, 1)?;
let vn = v.narrow(D::Minus1, l - 1, 1)?;
self.k_state = self.k_state.broadcast_add(&kn)?;
self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;
let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
num.broadcast_div(&den)?
} else {
self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
self.kv_state = k
.transpose(2, 3)?
.matmul(&v)?
.transpose(2, 3)?
.unsqueeze(2)?;
let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;
let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
};
let (b, h, l, d) = y.dims4()?;
let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
let y = self.out_proj.forward(&y)?;
Ok(y)
}
}
// Rotary embeddings used in local attention.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = 2048; // Hardcoded, missing from config.
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
// Local attention using a small sliding window.
#[derive(Debug, Clone)]
struct SlidingWindowAttention {
wqkv: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl SlidingWindowAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let num_heads = cfg.swa.num_heads;
let head_dim = hidden_size / num_heads;
let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
Ok(Self {
wqkv,
out_proj,
hidden_size,
num_heads,
head_dim,
rotary_emb,
kv_cache: None,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let qkv = xs.apply(&self.wqkv)?;
let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;
let q = qkv.i((.., .., 0))?;
let k = qkv.i((.., .., 1))?;
let v = qkv.i((.., .., 2))?;
let q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let (q, k) = self
.rotary_emb
.apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 2)?;
let v = Tensor::cat(&[prev_v, &v], 2)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
let out = attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.out_proj)?;
Ok(out)
}
}
// The model layers use three types of mixers.
#[derive(Debug, Clone)]
enum SequenceMixer {
Based(BasedConv),
Linear(LinearAttention),
Sliding(SlidingWindowAttention),
}
impl SequenceMixer {
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
pos: usize,
) -> Result<Tensor> {
match self {
Self::Based(b) => b.forward(xs, pos),
Self::Linear(b) => b.forward(xs, pos),
Self::Sliding(b) => b.forward(xs, attention_mask, pos),
}
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
mlp: MLP,
norm1: RmsNorm,
norm2: RmsNorm,
mixer: SequenceMixer,
}
impl DecoderLayer {
fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;
let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);
let mixer = if l_attn {
SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
} else if sw_attn {
SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
} else {
SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
};
Ok(Self {
mlp,
norm1,
norm2,
mixer,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.norm1.forward(xs)?;
let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: super::with_tracing::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
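// Round the vocabulary size up to the next multiple of 8 to match the padded lm_head weights
// (the embedding below reuses the same tensor).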
let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
let vb_m = vb.pp("transformer");
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.swa.window_size,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let sliding_window = self.sliding_window / 2;
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}


@@ -419,7 +419,7 @@ struct BertEncoder {
impl BertEncoder {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
.map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
@@ -454,8 +454,8 @@ impl BertModel {
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
) {
(embeddings, encoder)
} else {
@@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
};
let attention_mask = attention_mask.to_dtype(dtype)?;
// torch.finfo(dtype).min
(attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
(attention_mask.ones_like()? - &attention_mask)?
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
}


@@ -298,7 +298,7 @@ impl GPTBigCode {
let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
let blocks = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
.map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;


@@ -0,0 +1,376 @@
/// Adapted from <https://github.com/descriptinc/descript-audio-codec>
use crate::models::encodec;
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub num_codebooks: usize,
pub model_bitrate: u32,
pub codebook_size: usize,
pub latent_dim: usize,
pub frame_rate: u32,
pub sampling_rate: u32,
}
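// Snake activation: x + (1 / alpha) * sin^2(alpha * x), applied with a per-channel alpha
// (the 1e-9 below guards against a division by zero).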
#[derive(Debug, Clone)]
pub struct Snake1d {
alpha: Tensor,
}
impl Snake1d {
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
let alpha = vb.get((1, channels, 1), "alpha")?;
Ok(Self { alpha })
}
}
impl candle::Module for Snake1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs_shape = xs.shape();
let xs = xs.flatten_from(2)?;
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
let sin = (&sin * &sin)?;
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
}
}
#[derive(Debug, Clone)]
pub struct ResidualUnit {
snake1: Snake1d,
conv1: Conv1d,
snake2: Snake1d,
conv2: Conv1d,
}
impl ResidualUnit {
pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
let pad = ((7 - 1) * dilation) / 2;
let vb = vb.pp("block");
let snake1 = Snake1d::new(dim, vb.pp(0))?;
let cfg1 = Conv1dConfig {
dilation,
padding: pad,
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
let snake2 = Snake1d::new(dim, vb.pp(2))?;
let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
Ok(Self {
snake1,
conv1,
snake2,
conv2,
})
}
}
impl candle::Module for ResidualUnit {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = xs
.apply(&self.snake1)?
.apply(&self.conv1)?
.apply(&self.snake2)?
.apply(&self.conv2)?;
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
if pad > 0 {
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
} else {
ys + xs
}
}
}
#[derive(Debug, Clone)]
pub struct EncoderBlock {
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
snake1: Snake1d,
conv1: Conv1d,
}
impl EncoderBlock {
pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("block");
let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?;
let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?;
let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?;
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
let cfg1 = Conv1dConfig {
stride,
padding: (stride + 1) / 2,
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
Ok(Self {
res1,
res2,
res3,
snake1,
conv1,
})
}
}
impl candle::Module for EncoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)?
.apply(&self.snake1)?
.apply(&self.conv1)
}
}
#[derive(Debug, Clone)]
pub struct Encoder {
conv1: Conv1d,
blocks: Vec<EncoderBlock>,
snake1: Snake1d,
conv2: Conv1d,
}
impl candle::Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.conv1)?;
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.snake1)?.apply(&self.conv2)
}
}
impl Encoder {
pub fn new(
mut d_model: usize,
strides: &[usize],
d_latent: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("block");
let cfg1 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?;
let mut blocks = Vec::with_capacity(strides.len());
for (block_idx, stride) in strides.iter().enumerate() {
d_model *= 2;
let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?;
blocks.push(block)
}
let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?;
let cfg2 = Conv1dConfig {
padding: 1,
..Default::default()
};
let conv2 =
encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?;
Ok(Self {
conv1,
blocks,
snake1,
conv2,
})
}
}
#[derive(Debug, Clone)]
pub struct DecoderBlock {
snake1: Snake1d,
conv_tr1: ConvTranspose1d,
res1: ResidualUnit,
res2: ResidualUnit,
res3: ResidualUnit,
}
impl DecoderBlock {
pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("block");
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
let cfg = ConvTranspose1dConfig {
stride,
padding: (stride + 1) / 2,
..Default::default()
};
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
in_dim,
out_dim,
2 * stride,
true,
cfg,
vb.pp(1),
)?;
let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?;
let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?;
let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?;
Ok(Self {
snake1,
conv_tr1,
res1,
res2,
res3,
})
}
}
impl candle_nn::Module for DecoderBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.snake1)?
.apply(&self.conv_tr1)?
.apply(&self.res1)?
.apply(&self.res2)?
.apply(&self.res3)
}
}
#[derive(Debug, Clone)]
pub struct Decoder {
conv1: Conv1d,
blocks: Vec<DecoderBlock>,
snake1: Snake1d,
conv2: Conv1d,
}
impl Decoder {
pub fn new(
in_c: usize,
mut channels: usize,
rates: &[usize],
d_out: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("model");
let cfg1 = Conv1dConfig {
padding: 3,
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?;
let mut blocks = Vec::with_capacity(rates.len());
for (idx, stride) in rates.iter().enumerate() {
let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?;
channels /= 2;
blocks.push(block)
}
let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?;
let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?;
Ok(Self {
conv1,
blocks,
snake1,
conv2,
})
}
}
impl candle::Module for Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.conv1)?;
for block in self.blocks.iter() {
xs = xs.apply(block)?
}
xs.apply(&self.snake1)?.apply(&self.conv2)
}
}
#[allow(unused)]
#[derive(Clone, Debug)]
pub struct VectorQuantizer {
in_proj: Conv1d,
out_proj: Conv1d,
codebook: candle_nn::Embedding,
}
impl VectorQuantizer {
pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> {
let in_proj =
encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
let out_proj =
encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
Ok(Self {
in_proj,
out_proj,
codebook,
})
}
pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
embed_id.apply(&self.codebook)
}
pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
self.embed_code(embed_id)?.transpose(1, 2)
}
}
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
quantizers: Vec<VectorQuantizer>,
}
impl ResidualVectorQuantizer {
pub fn new(
input_dim: usize,
n_codebooks: usize,
cb_size: usize,
cb_dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = &vb.pp("quantizers");
let quantizers = (0..n_codebooks)
.map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self { quantizers })
}
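// Decode a set of codes: each codebook stage looks up its embeddings, projects them back to the
// latent dimension with `out_proj`, and the per-stage contributions are summed.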
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
let mut sum = None;
for (idx, quantizer) in self.quantizers.iter().enumerate() {
let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?;
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
let s = match sum {
None => z_q_i,
Some(s) => (s + z_q_i)?,
};
sum = Some(s)
}
match sum {
Some(s) => Ok(s),
None => candle::bail!("empty codebooks"),
}
}
}
#[derive(Debug, Clone)]
pub struct Model {
pub encoder: Encoder,
pub quantizer: ResidualVectorQuantizer,
pub decoder: Decoder,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("model");
let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?;
let quantizer = ResidualVectorQuantizer::new(
cfg.latent_dim,
cfg.num_codebooks,
cfg.codebook_size,
8,
vb.pp("quantizer"),
)?;
let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> {
let audio_values = self.quantizer.from_codes(audio_codes)?;
audio_values.apply(&self.decoder)
}
}


@@ -275,7 +275,7 @@ struct Transformer {
impl Transformer {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.n_layers)
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
.map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(Transformer { layers, span })
@@ -311,8 +311,8 @@ impl DistilBertModel {
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
Transformer::load(vb.pp(format!("{model_type}.transformer")), config),
) {
(embeddings, encoder)
} else {


@@ -136,7 +136,7 @@ pub fn conv1d_weight_norm(
Ok(Conv1d::new(weight, Some(bias), config))
}
fn conv_transpose1d_weight_norm(
pub fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,


@@ -448,7 +448,7 @@ impl Falcon {
vb.pp("transformer.word_embeddings"),
)?;
let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
.map(|i| FalconDecoderLayer::load(vb.pp(format!("transformer.h.{i}")), &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(
cfg.hidden_size,


@@ -0,0 +1,512 @@
//! FastViT inference implementation based on timm
//!
//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"
//! <https://arxiv.org/pdf/2303.14189>
//!
//! <https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py>
use candle::{DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
#[derive(Clone, Debug)]
pub struct Config {
exp_ratio: usize,
in_channels: usize,
blocks: [usize; 4],
attn: bool,
lkc_use_act: bool,
}
impl Config {
pub fn t8() -> Self {
Self {
exp_ratio: 3,
in_channels: 48,
blocks: [2, 2, 4, 2],
attn: false,
lkc_use_act: false,
}
}
pub fn t12() -> Self {
Self {
exp_ratio: 3,
in_channels: 64,
blocks: [2, 2, 6, 2],
attn: false,
lkc_use_act: false,
}
}
pub fn s12() -> Self {
Self {
exp_ratio: 4,
in_channels: 64,
blocks: [2, 2, 6, 2],
attn: false,
lkc_use_act: false,
}
}
pub fn sa12() -> Self {
Self {
exp_ratio: 4,
in_channels: 64,
blocks: [2, 2, 6, 2],
attn: true,
lkc_use_act: false,
}
}
pub fn sa24() -> Self {
Self {
exp_ratio: 4,
in_channels: 64,
blocks: [4, 4, 12, 4],
attn: true,
lkc_use_act: false,
}
}
pub fn sa36() -> Self {
Self {
exp_ratio: 4,
in_channels: 64,
blocks: [6, 6, 18, 6],
attn: true,
lkc_use_act: false,
}
}
pub fn ma36() -> Self {
Self {
exp_ratio: 4,
in_channels: 76,
blocks: [6, 6, 18, 6],
attn: true,
lkc_use_act: false,
}
}
// configs used by MobileCLIP's image encoder
pub fn mci0() -> Self {
Self {
exp_ratio: 3,
in_channels: 64,
blocks: [2, 6, 10, 2],
attn: true,
lkc_use_act: true,
}
}
pub fn mci1() -> Self {
Self {
exp_ratio: 3,
in_channels: 64,
blocks: [4, 12, 20, 4],
attn: true,
lkc_use_act: true,
}
}
pub fn mci2() -> Self {
Self {
exp_ratio: 3,
in_channels: 80,
blocks: [4, 12, 24, 4],
attn: true,
lkc_use_act: true,
}
}
}
fn conv_norm(
in_channels: usize,
out_channels: usize,
kernel: usize,
stride: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
padding: kernel / 2,
groups: in_channels,
..Default::default()
};
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
let conv = conv.absorb_bn(&bn)?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&conv)?;
Ok(xs)
}))
}
fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?;
let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
Ok(xs)
}))
}
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
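// W' = W * gamma / sigma (per output channel) and b' = beta - mu * gamma / sigma,
// with sigma = sqrt(running_var + eps).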
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
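// A re-parameterized MobileOne block: the kxk conv + BN branch, the 1x1 "scale" conv + BN branch
// (zero-padded up to the kxk kernel) and the identity BN branch are each folded into a
// (weight, bias) pair and summed, so inference only runs a single fused convolution
// (plus optional squeeze-and-excitation and activation).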
fn mobileone_block(
in_channels: usize,
out_channels: usize,
kernel: usize,
stride: usize,
group_size: usize,
use_act: bool,
vb: VarBuilder,
) -> Result<Func<'static>> {
let groups = if group_size == 0 {
1
} else {
in_channels / group_size
};
let padding = kernel / 2;
let conv2d_cfg = Conv2dConfig {
stride,
groups,
padding,
..Default::default()
};
let mut w = Tensor::zeros(
(out_channels, in_channels / groups, kernel, kernel),
DType::F32,
vb.device(),
)?;
let dim = out_channels;
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
let conv_kxk = conv2d_no_bias(
in_channels,
out_channels,
kernel,
conv2d_cfg,
vb.pp("conv_kxk.0.conv"),
);
if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
w = (w + wk)?;
b = (b + bk)?;
};
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
let conv_scale = conv2d_no_bias(
in_channels,
out_channels,
1,
conv2d_cfg,
vb.pp("conv_scale.conv"),
);
if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
// pad to 3x3
let ws = ws
.pad_with_zeros(D::Minus1, 1, 1)?
.pad_with_zeros(D::Minus2, 1, 1)?;
w = (w + ws)?;
b = (b + bs)?;
};
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));
// read and reparameterize the identity bn into wi and bi
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"));
if let Ok(id_bn) = identity_bn {
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
let id = in_channels / groups;
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
for i in 0..in_channels {
if kernel > 1 {
weights[i * kernel * kernel + 4] = 1.0;
} else {
weights[i * (id + 1)] = 1.0;
}
}
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
w = (w + wi)?;
b = (b + bi)?;
};
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
Ok(Func::new(move |xs| {
let mut xs = xs.apply(&reparam_conv)?;
if let Ok(f) = &se {
xs = xs.apply(f)?;
}
if use_act {
xs = xs.gelu_erf()?;
};
Ok(xs)
}))
}
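// RepMixer token mixing: y = x + gamma * (mixer(x) - norm(x)), where `mixer` and `norm` are both
// re-parameterized depthwise MobileOne blocks.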
fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result<Func<'static>> {
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?;
let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?;
Ok(Func::new(move |xs| {
let residual = xs.clone();
let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?;
let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
let xs = (xs + residual)?;
Ok(xs)
}))
}
fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?;
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
Ok(Func::new(move |xs| {
let residual = xs.apply(&token_mixer)?;
let mut xs = residual.apply(&mlp)?;
xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
let xs = (xs + residual)?;
Ok(xs)
}))
}
fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride: 1,
padding: 3,
groups: dim,
..Default::default()
};
let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?;
Ok(Func::new(move |xs| {
let xs = (xs + xs.apply(&conv)?)?;
Ok(xs)
}))
}
fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
let proj = linear(dim, dim, vb.pp("proj"))?;
let head_dim = 32;
let num_heads = dim / head_dim;
let scale = (head_dim as f64).powf(-0.5);
Ok(Func::new(move |xs| {
let xs = xs.clone();
let (b, c, h, w) = xs.dims4()?;
let n = h * w;
let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?;
let qkv = xs
.apply(&qkv)?
.reshape((b, n, 3, num_heads, head_dim))?
.permute((2, 0, 3, 1, 4))?;
let q = qkv.get(0)?;
let k = qkv.get(1)?;
let v = qkv.get(2)?;
let q = (q * scale)?;
let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
let att = softmax(&att, D::Minus1)?;
let xs = att.matmul(&v)?;
let xs = xs.transpose(1, 2)?.reshape((b, n, c))?;
let xs = xs.apply(&proj)?;
let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?;
Ok(xs)
}))
}
fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?;
let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?;
let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?;
let token_mixer = attention(dim, vb.pp("token_mixer"))?;
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
Ok(Func::new(move |xs| {
let xs = xs.clone();
let xs = (&xs
+ &xs
.apply_t(&norm, false)?
.apply(&token_mixer)?
.broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?;
let xs = (&xs
+ &xs
.apply(&mlp)?
.broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?;
Ok(xs)
}))
}
fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = cfg.blocks[idx];
let mut blocks = Vec::with_capacity(nblocks);
let dim = cfg.in_channels << idx;
let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample"));
for block_idx in 0..nblocks {
let block = if cfg.attn && idx == 3 {
attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
} else {
repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
};
blocks.push(block);
}
let pos_emb = positional_encoding(dim, vb.pp("pos_emb"));
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
if let Ok(ds) = &downsample {
xs = xs.apply(ds)?;
}
if let Ok(pos) = &pos_emb {
xs = xs.apply(pos)?;
}
for block in blocks.iter() {
xs = xs.apply(block)?;
}
Ok(xs)
}))
}
fn fastvit_patch_embed(
in_channels: usize,
out_channels: usize,
use_act: bool,
vb: VarBuilder,
) -> Result<Func<'static>> {
let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
Ok(Func::new(move |xs| {
let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?;
if let Ok(f) = &se {
xs = xs.apply(f)?;
}
if use_act {
xs = xs.gelu_erf()?;
};
let xs = xs.apply(&mb)?;
Ok(xs)
}))
}
fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?;
let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?;
let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?;
Ok(xs)
}))
}
// Build a fastvit model for a given configuration.
fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?;
Some(linear)
}
};
let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?;
let final_conv = mobileone_block(
cfg.in_channels * 8,
cfg.in_channels * 16,
3,
1,
1,
true,
vb.pp("final_conv"),
)?;
let vb = vb.pp("stages");
let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?;
let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?;
let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?;
let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.apply(&final_conv)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
}
}))
}
pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
fastvit_model(cfg, Some(nclasses), vb)
}
pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
fastvit_model(cfg, None, vb)
}


@@ -0,0 +1,449 @@
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
fn default_max_position_embeddings() -> usize {
4096
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
pub hidden_activation: Activation,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub vocab_size: usize,
pub final_logit_softcapping: Option<f64>,
pub attn_logit_softcapping: Option<f64>,
pub query_pre_attn_scalar: usize,
// TODO: Handle the sliding window in the attention mask.
pub sliding_window: Option<usize>,
#[serde(default = "default_max_position_embeddings")]
pub max_position_embeddings: usize,
}
#[derive(Debug, Clone)]
struct RmsNorm {
weight: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(dim, "weight")?;
Ok(Self { weight, eps })
}
}
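// Note that the forward pass scales by (1 + weight) rather than weight, matching the RMSNorm
// convention used by the Gemma family of models.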
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&(&self.weight + 1.0)?)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_activation,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
attn_logit_softcapping: Option<f64>,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_flash_attn: bool,
}
impl Attention {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = cfg.head_dim;
let bias = cfg.attention_bias;
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
attn_logit_softcapping: cfg.attn_logit_softcapping,
rotary_emb,
kv_cache: None,
use_flash_attn,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = query_states.transpose(1, 2)?;
let k = key_states.transpose(1, 2)?;
let v = value_states.transpose(1, 2)?;
let scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
} else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
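// Attention logit soft-capping: sc * tanh(x / sc) keeps the logits within (-sc, sc).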
let attn_weights = match self.attn_logit_softcapping {
None => attn_weights,
Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
};
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, ()))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
pre_feedforward_layernorm: RmsNorm,
post_feedforward_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let pre_feedforward_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("pre_feedforward_layernorm"),
)?;
let post_feedforward_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_feedforward_layernorm"),
)?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
pre_feedforward_layernorm,
post_feedforward_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = xs.apply(&self.post_attention_layernorm)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.pre_feedforward_layernorm)?;
let xs = xs.apply(&self.mlp)?;
let xs = xs.apply(&self.post_feedforward_layernorm)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
final_logit_softcapping: Option<f64>,
device: Device,
dtype: DType,
hidden_size: usize,
sliding_window: Option<usize>,
}
impl Model {
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer =
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
final_logit_softcapping: cfg.final_logit_softcapping,
device: vb.device().clone(),
dtype: vb.dtype(),
hidden_size: cfg.hidden_size,
sliding_window: cfg.sliding_window,
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = match self.sliding_window {
None => (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect(),
Some(sliding_window) => (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect(),
};
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
let logits = xs
.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)?;
let logits = match self.final_logit_softcapping {
None => logits,
Some(sc) => ((logits / sc)?.tanh()? * sc)?,
};
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
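// Editor's sketch (not part of the diff): a minimal, hypothetical greedy decoding loop for the
// model above. The function name and the `prompt`/`max_new_tokens` parameters are assumptions
// introduced for illustration; sampling is reduced to an argmax and error handling is elided.
// It only shows how `forward`, `seqlen_offset` and `clear_kv_cache` are meant to interact.
fn greedy_decode_sketch(
    cfg: &Config,
    vb: VarBuilder,
    device: &Device,
    prompt: &[u32],
    max_new_tokens: usize,
) -> Result<Vec<u32>> {
    let mut model = Model::new(/* use_flash_attn */ false, cfg, vb)?;
    model.clear_kv_cache();
    let mut tokens = prompt.to_vec();
    let mut offset = 0usize;
    for _ in 0..max_new_tokens {
        // Feed the whole prompt on the first step, then only the freshly generated token.
        let ctx = &tokens[offset..];
        let input = Tensor::new(ctx, device)?.unsqueeze(0)?;
        let logits = model.forward(&input, offset)?; // shape (1, 1, vocab_size)
        offset += ctx.len();
        let next = logits
            .squeeze(0)?
            .squeeze(0)?
            .argmax(D::Minus1)?
            .to_scalar::<u32>()?;
        tokens.push(next);
    }
    Ok(tokens)
}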

View File

@ -0,0 +1,595 @@
use crate::models::with_tracing::{linear_b as linear, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug, Clone)]
pub struct Config {
pub num_layers: usize,
pub padded_vocab_size: usize,
pub hidden_size: usize,
pub ffn_hidden_size: usize,
pub kv_channels: usize,
pub num_attention_heads: usize,
pub seq_length: usize,
pub layernorm_epsilon: f64,
pub rmsnorm: bool,
pub apply_residual_connection_post_layernorm: bool,
pub post_layer_norm: bool,
pub add_bias_linear: bool,
pub add_qkv_bias: bool,
pub bias_dropout_fusion: bool,
pub multi_query_attention: bool,
pub multi_query_group_num: usize,
pub apply_query_key_layer_scaling: bool,
pub attention_softmax_in_fp32: bool,
pub fp32_residual_connection: bool,
}
impl Config {
pub fn glm4() -> Self {
Self {
num_layers: 40,
padded_vocab_size: 151552,
hidden_size: 4096,
ffn_hidden_size: 13696,
kv_channels: 128,
num_attention_heads: 32,
seq_length: 8192,
layernorm_epsilon: 1e-5,
rmsnorm: true,
apply_residual_connection_post_layernorm: false,
post_layer_norm: true,
add_bias_linear: false,
add_qkv_bias: true,
bias_dropout_fusion: true,
multi_query_attention: true,
multi_query_group_num: 2,
apply_query_key_layer_scaling: true,
attention_softmax_in_fp32: true,
fp32_residual_connection: false,
}
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
cache: Tensor,
}
impl RotaryEmbedding {
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
let rotary_dim = cfg.kv_channels;
let n_elem = rotary_dim / 2;
let inv_freq: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
.to_dtype(dtype)?
.reshape((cfg.seq_length, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
Ok(Self { cache })
}
fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (seqlen, _b, np, _hn) = xs.dims4()?;
let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
let rot_dim = cache.dim(D::Minus2)? * 2;
let (xs, xs_pass) = (
xs.narrow(D::Minus1, 0, rot_dim)?,
xs.narrow(D::Minus1, rot_dim, rot_dim)?,
);
let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
let (xshaped0, xshaped1) = (
xshaped.i((.., .., .., .., 0))?,
xshaped.i((.., .., .., .., 1))?,
);
let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
let xs_out = Tensor::stack(
&[
(xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
(xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
],
D::Minus1,
)?;
let xs_out = xs_out.flatten_from(3)?;
Tensor::cat(&[xs_out, xs_pass], D::Minus1)
}
}
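// Editor's note (not part of the diff): `apply` above performs the usual interleaved rotary
// rotation. For each pair (x0, x1) taken from the first `rot_dim` channels and the cached
// (cos, sin) at that position:
//   x0' = x0 * cos - x1 * sin
//   x1' = x1 * cos + x0 * sin
// The remaining `xs_pass` channels are left untouched and concatenated back at the end.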
#[derive(Debug, Clone)]
struct CoreAttention {
coeff: Option<f64>,
norm_factor: f64,
dtype: DType,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
Ok(m)
}
impl CoreAttention {
fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
let norm_factor = (cfg.kv_channels as f64).sqrt();
let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
let coeff = f64::max(1.0, layer_number as f64);
(norm_factor * coeff, Some(coeff))
} else {
(norm_factor, None)
};
Ok(Self {
coeff,
norm_factor,
dtype,
})
}
fn forward(
&self,
query_layer: &Tensor,
key_layer: &Tensor,
value_layer: &Tensor,
attention_mask: &Option<Tensor>,
) -> Result<Tensor> {
let output_size = (
query_layer.dim(1)?, // b
query_layer.dim(2)?, // np
query_layer.dim(0)?, // sq
key_layer.dim(0)?, // sk
);
let query_layer =
query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
let matmul_result = Tensor::matmul(
&query_layer.transpose(0, 1)?.contiguous()?,
&key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
)?;
let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
let matmul_result = match self.coeff {
None => matmul_result,
Some(coeff) => (matmul_result * coeff)?,
};
let attention_scores = match attention_mask {
Some(mask) => masked_fill(
&matmul_result,
&mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
f32::NEG_INFINITY,
self.dtype,
)?,
None => matmul_result,
};
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let output_size = (
value_layer.dim(1)?,
value_layer.dim(2)?,
query_layer.dim(0)?,
value_layer.dim(3)?,
);
let value_layer =
value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
let attention_probs =
attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
let context_layer = Tensor::matmul(
&attention_probs.contiguous()?,
&value_layer.transpose(0, 1)?.contiguous()?,
)?;
let context_layer = context_layer.reshape(output_size)?;
let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
context_layer.flatten_from(D::Minus2)
}
}
#[derive(Debug, Clone)]
struct SelfAttention {
query_key_value: Linear,
core_attention: CoreAttention,
dense: Linear,
multi_query_attention: bool,
num_attention_heads_per_partition: usize,
num_multi_query_groups_per_partition: usize,
hidden_size_per_attention_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl SelfAttention {
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let projection_size = cfg.kv_channels * cfg.num_attention_heads;
let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
let qkv_hidden_size = if cfg.multi_query_attention {
projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
} else {
3 * projection_size
};
let query_key_value = linear(
cfg.hidden_size,
qkv_hidden_size,
cfg.add_bias_linear || cfg.add_qkv_bias,
vb.pp("query_key_value"),
)?;
let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
let dense = linear(
cfg.hidden_size,
cfg.hidden_size,
cfg.add_bias_linear,
vb.pp("dense"),
)?;
Ok(Self {
query_key_value,
core_attention,
dense,
multi_query_attention: cfg.multi_query_attention,
num_attention_heads_per_partition: cfg.num_attention_heads,
num_multi_query_groups_per_partition: cfg.multi_query_group_num,
hidden_size_per_attention_head: cfg.kv_channels,
kv_cache: None,
})
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Option<Tensor>,
rotary_emb: &RotaryEmbedding,
) -> Result<Tensor> {
let mixed_x_layer = xs.apply(&self.query_key_value)?;
if !self.multi_query_attention {
candle::bail!("only multi_query_attention=true is supported")
}
let hpa = self.hidden_size_per_attention_head;
let query_layer =
mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
let key_layer = mixed_x_layer.narrow(
D::Minus1,
self.num_attention_heads_per_partition * hpa,
self.num_multi_query_groups_per_partition * hpa,
)?;
let value_layer = mixed_x_layer.narrow(
D::Minus1,
self.num_attention_heads_per_partition * hpa
+ self.num_multi_query_groups_per_partition * hpa,
self.num_multi_query_groups_per_partition * hpa,
)?;
let query_layer = query_layer.reshape((
query_layer.dim(0)?,
query_layer.dim(1)?,
self.num_attention_heads_per_partition,
hpa,
))?;
let key_layer = key_layer.reshape((
key_layer.dim(0)?,
key_layer.dim(1)?,
self.num_multi_query_groups_per_partition,
hpa,
))?;
let value_layer = value_layer.reshape((
value_layer.dim(0)?,
value_layer.dim(1)?,
self.num_multi_query_groups_per_partition,
hpa,
))?;
// Rotary embeddings.
let seqlen_offset = match &self.kv_cache {
None => 0,
Some((prev_k, _)) => prev_k.dim(0)?,
};
let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
// KV cache.
let (key_layer, value_layer) = match &self.kv_cache {
None => (key_layer, value_layer),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
(k, v)
}
};
self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
// Repeat KV.
let ratio =
self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
let key_layer = {
let (d0, d1, d2, d3) = key_layer.dims4()?;
key_layer
.unsqueeze(D::Minus2)?
.expand((d0, d1, d2, ratio, d3))?
.reshape((
d0,
d1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))?
};
let value_layer = {
let (d0, d1, d2, d3) = value_layer.dims4()?;
value_layer
.unsqueeze(D::Minus2)?
.expand((d0, d1, d2, ratio, d3))?
.reshape((
d0,
d1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))?
};
let context_layer =
self.core_attention
.forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
let output = context_layer.apply(&self.dense)?;
Ok(output)
}
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
struct MLP {
dense_h_to_4h: Linear,
dense_4h_to_h: Linear,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense_h_to_4h = linear(
cfg.hidden_size,
cfg.ffn_hidden_size * 2,
cfg.add_bias_linear,
vb.pp("dense_h_to_4h"),
)?;
let dense_4h_to_h = linear(
cfg.ffn_hidden_size,
cfg.hidden_size,
cfg.add_bias_linear,
vb.pp("dense_4h_to_h"),
)?;
Ok(Self {
dense_4h_to_h,
dense_h_to_4h,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.dense_h_to_4h)?
.apply(&candle_nn::Activation::Swiglu)?
.apply(&self.dense_4h_to_h)
}
}
#[derive(Debug, Clone)]
struct Block {
input_layernorm: candle_nn::LayerNorm,
self_attention: SelfAttention,
post_attention_layernorm: candle_nn::LayerNorm,
mlp: MLP,
apply_residual_connection_post_layernorm: bool,
}
impl Block {
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let input_layernorm = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("input_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("input_layernorm"),
)?
};
let post_attention_layernorm = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("post_attention_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("post_attention_layernorm"),
)?
};
let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Self {
input_layernorm,
self_attention,
post_attention_layernorm,
mlp,
apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
})
}
fn reset_kv_cache(&mut self) {
self.self_attention.reset_kv_cache()
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Option<Tensor>,
rotary_emb: &RotaryEmbedding,
) -> Result<Tensor> {
let layernorm_output = xs.apply(&self.input_layernorm)?;
let attention_output =
self.self_attention
.forward(&layernorm_output, attention_mask, rotary_emb)?;
let residual = if self.apply_residual_connection_post_layernorm {
&layernorm_output
} else {
xs
};
let layernorm_input = (residual + attention_output)?;
let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
let mlp_output = layernorm_output.apply(&self.mlp)?;
let residual = if self.apply_residual_connection_post_layernorm {
&layernorm_output
} else {
&layernorm_input
};
mlp_output + residual
}
}
#[derive(Debug, Clone)]
struct Transformer {
layers: Vec<Block>,
final_layernorm: Option<candle_nn::LayerNorm>,
rotary_emb: RotaryEmbedding,
}
impl Transformer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_l = vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_layers);
for layer_index in 0..cfg.num_layers {
let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
layers.push(block)
}
let final_layernorm = if cfg.post_layer_norm {
let ln = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("final_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("final_layernorm"),
)?
};
Some(ln)
} else {
None
};
let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
Ok(Self {
layers,
final_layernorm,
rotary_emb,
})
}
fn reset_kv_cache(&mut self) {
for block in self.layers.iter_mut() {
block.reset_kv_cache()
}
}
fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for block in self.layers.iter_mut() {
xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
}
match self.final_layernorm.as_ref() {
None => Ok(xs),
Some(ln) => xs.apply(ln),
}
}
}
#[derive(Debug, Clone)]
struct Embedding {
word_embeddings: candle_nn::Embedding,
fp32_residual_connection: bool,
}
impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let word_embeddings = candle_nn::embedding(
cfg.padded_vocab_size,
cfg.hidden_size,
vb.pp("word_embeddings"),
)?;
Ok(Self {
word_embeddings,
fp32_residual_connection: cfg.fp32_residual_connection,
})
}
}
impl Module for Embedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
if self.fp32_residual_connection {
xs.to_dtype(candle::DType::F32)
} else {
xs.contiguous()
}
}
}
#[derive(Debug, Clone)]
pub struct Model {
embedding: Embedding,
encoder: Transformer,
output_layer: Linear,
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device)
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("transformer");
let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
let output_layer = linear(
cfg.hidden_size,
cfg.padded_vocab_size,
false,
vb.pp("output_layer"),
)?;
Ok(Self {
embedding,
encoder,
output_layer,
})
}
pub fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache()
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let (_b_size, seq_len) = xs.dims2()?;
let input_embeds = xs.apply(&self.embedding)?;
let attention_mask = if seq_len <= 1 {
None
} else {
Some(get_mask(seq_len, xs.device())?)
};
let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
Ok(lm_logits)
}
}
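// Editor's sketch (not part of the diff): this model keeps its KV cache inside the blocks, so
// a prefill is just a single `forward` call and `reset_kv_cache` is used between prompts.
// The function name and the `token_ids` parameter are assumptions introduced for illustration.
fn glm4_prefill_sketch(vb: VarBuilder, device: &Device, token_ids: &[u32]) -> Result<Tensor> {
    let cfg = Config::glm4();
    let mut model = Model::new(&cfg, vb)?;
    model.reset_kv_cache();
    let xs = Tensor::new(token_ids, device)?.unsqueeze(0)?; // (1, seq_len)
    // `forward` returns the logits for the last position only, shape (1, padded_vocab_size).
    model.forward(&xs)
}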

View File

@ -0,0 +1,458 @@
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::{collections::HashMap, f32::consts::PI};
pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub enum GraniteRopeType {
#[serde(rename = "granite")]
Granite,
#[default]
#[serde(rename = "default")]
Default,
}
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct GraniteRopeConfig {
pub factor: f32,
pub low_freq_factor: f32,
pub high_freq_factor: f32,
pub original_max_position_embeddings: usize,
pub rope_type: GraniteRopeType,
}
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum GraniteEosToks {
Single(u32),
Multiple(Vec<u32>),
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct GraniteConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: Option<usize>,
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<GraniteEosToks>,
pub rope_scaling: Option<GraniteRopeConfig>,
pub max_position_embeddings: usize,
}
impl GraniteConfig {
pub fn num_key_value_heads(&self) -> usize {
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
}
}
fn default_rope() -> f32 {
10_000.0
}
impl GraniteConfig {
pub fn into_config(self, use_flash_attn: bool) -> Config {
Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads(),
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
rope_scaling: self.rope_scaling,
max_position_embeddings: self.max_position_embeddings,
}
}
}
#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<GraniteEosToks>,
pub rope_scaling: Option<GraniteRopeConfig>,
pub max_position_embeddings: usize,
}
#[derive(Debug, Clone)]
pub struct Cache {
masks: HashMap<usize, Tensor>,
pub use_kv_cache: bool,
kvs: Vec<Option<(Tensor, Tensor)>>,
cos: Tensor,
sin: Tensor,
device: Device,
}
fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect()
}
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let theta = match &config.rope_scaling {
None
| Some(GraniteRopeConfig {
rope_type: GraniteRopeType::Default,
..
}) => calculate_default_inv_freq(config),
Some(rope_scaling) => {
let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.low_freq_factor;
let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.high_freq_factor;
calculate_default_inv_freq(config)
.into_iter()
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / rope_scaling.factor
} else {
let smooth = (rope_scaling.original_max_position_embeddings as f32
/ wavelen
- rope_scaling.low_freq_factor)
/ (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
(1. - smooth) * freq / rope_scaling.factor + smooth * freq
}
})
.collect::<Vec<_>>()
}
};
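// Editor's note (not part of the diff): the branch above implements the llama3-style frequency
// scaling. With wavelen = 2*pi/freq, frequencies whose wavelength is shorter than
// high_freq_wavelen are kept as-is, those longer than low_freq_wavelen are divided by `factor`,
// and the band in between is linearly interpolated via `smooth`.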
let theta = Tensor::new(theta, device)?;
let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
.to_dtype(DType::F32)?
.reshape((config.max_position_embeddings, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
masks: HashMap::new(),
use_kv_cache,
kvs: vec![None; config.num_hidden_layers],
device: device.clone(),
cos,
sin,
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
}
#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: usize,
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
max_position_embeddings: usize,
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
candle_nn::rotary_emb::rope(x, &cos, &sin)
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let mut v = v
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
if k_seq_len > self.max_position_embeddings {
k = k
.narrow(
D::Minus1,
k_seq_len - self.max_position_embeddings,
self.max_position_embeddings,
)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > 2 * self.max_position_embeddings {
v = v
.narrow(
D::Minus1,
v_seq_len - self.max_position_embeddings,
self.max_position_embeddings,
)?
.contiguous()?
}
}
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let y = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
} else {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = if seq_len == 1 {
att
} else {
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
masked_fill(&att, &mask, f32::NEG_INFINITY)?
};
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads,
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
max_position_embeddings: cfg.max_position_embeddings,
})
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
span: tracing::Span,
}
impl Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "mlp");
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self {
c_fc1,
c_fc2,
c_proj,
span,
})
}
}
#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
span: tracing::Span,
}
impl Block {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
rms_1,
attn,
rms_2,
mlp,
span,
})
}
}
#[derive(Debug, Clone)]
pub struct Granite {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
}
impl Granite {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
.collect();
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
})
}
}

View File

@ -344,7 +344,7 @@ impl BertEncoder {
candle::bail!("only alibi is supported as a position-embedding-type")
}
let layers = (0..cfg.num_hidden_layers)
- .map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
+ .map(|index| BertLayer::new(vb.pp(format!("layer.{index}")), cfg))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;

View File

@ -507,7 +507,7 @@ impl Llama {
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
- .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
+ .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
.collect();
Ok(Self {

View File

@ -354,7 +354,7 @@ impl Llama {
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
- .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
+ .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
.collect();
Ok(Self {
wte,

View File

@ -0,0 +1,670 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
use candle_nn::{Conv1d, VarBuilder};
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Norm {
WeightNorm,
SpectralNorm,
TimeGroupNorm,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PadMode {
Constant,
Reflect,
Replicate,
}
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
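// Editor's note (not part of the diff): when only `weight_g`/`weight_v` are present, the
// effective weight is reconstructed as w = g * v / ||v||, with the norm taken over the
// (in_channels, kernel_size) dims of each output channel, which is what the code below does
// with `sum_keepdim((1, 2))` followed by the broadcast mul/div.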
fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight = if vb.contains_tensor("weight") {
vb.get((out_c, in_c, kernel_size), "weight")?
} else {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
};
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(Conv1d::new(weight, bias, config))
}
#[derive(Debug, Clone)]
pub struct NormConv1d {
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
cfg: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Self> {
let conv = match norm {
None | Some(Norm::TimeGroupNorm) => {
if bias {
candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
} else {
candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
}
}
Some(Norm::WeightNorm) => {
conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
conv,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
})
}
}
impl Module for NormConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = xs.apply(&self.conv)?;
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug, Clone)]
pub struct NormConvTranspose1d {
ws: Tensor,
bs: Option<Tensor>,
k_size: usize,
stride: usize,
groups: usize,
norm: Option<candle_nn::GroupNorm>,
span: tracing::Span,
}
impl NormConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
causal: bool,
norm: Option<Norm>,
bias: bool,
stride: usize,
groups: usize,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("conv");
let bs = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
let ws = match norm {
None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
Some(Norm::WeightNorm) => {
if vb.contains_tensor("weight") {
vb.get((in_c, out_c, k_size), "weight")?
} else {
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
}
}
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
};
let (ws, groups) = if groups == out_c && in_c == out_c {
let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
let ws = ws
.repeat((1, out_c, 1))?
.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
(ws, 1)
} else {
(ws, groups)
};
let norm = match norm {
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
Some(Norm::TimeGroupNorm) => {
if causal {
candle::bail!("GroupNorm doesn't support causal evaluation.")
}
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(norm)
}
};
Ok(Self {
ws,
bs,
k_size,
stride,
groups,
norm,
span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
})
}
}
impl Module for NormConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
// conv-transpose1d seems to be broken on metal after enough iterations, causing
// the following error:
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
// This is now fixed in candle.
let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
let xs = match &self.bs {
None => xs,
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
xs.broadcast_add(&bias)?
}
};
match self.norm.as_ref() {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
fn get_extra_padding_for_conv1d(
xs: &Tensor,
k_size: usize,
stride: usize,
padding_total: usize,
) -> Result<usize> {
let len = xs.dim(D::Minus1)?;
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
let ideal_len =
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
Ok(ideal_len.saturating_sub(len))
}
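// Editor's note (not part of the diff): a worked example of the helper above. With len = 10,
// k_size = 4, stride = 2 and padding_total = 2, n_frames = (10 + 2 - 4) / 2 + 1 = 5 and
// ideal_len = (5 - 1) * 2 + 4 - 2 = 10, so no extra padding is needed; with len = 9 the same
// computation gives n_frames = 4.5, which is rounded up to 5 frames and an extra padding of 1.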
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
match mode {
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
}
}
fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
let len = xs.dim(D::Minus1)?;
if len < unpad_l + unpad_r {
candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
}
xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
}
#[derive(Debug, Clone)]
pub struct StreamableConv1d {
conv: NormConv1d,
causal: bool,
pad_mode: PadMode,
state_prev_xs: StreamTensor,
left_pad_applied: bool,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConv1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
dilation: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
pad_mode: PadMode,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::Conv1dConfig {
padding: 0,
stride,
dilation,
groups,
};
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
if k_size < stride {
candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
}
Ok(Self {
conv,
causal,
pad_mode,
state_prev_xs: StreamTensor::empty(),
left_pad_applied: false,
kernel_size: k_size,
span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
})
}
}
impl Module for StreamableConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b, _t, _c) = xs.dims3()?;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
// Effective kernel size with dilations.
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
let extra_padding =
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
let xs = if self.causal {
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
pad1d(
xs,
padding_left,
padding_right + extra_padding,
self.pad_mode,
)?
};
xs.apply(&self.conv)
}
}
impl StreamingModule for StreamableConv1d {
fn reset_state(&mut self) {
self.state_prev_xs.reset();
self.left_pad_applied = false;
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
None => return Ok(().into()),
Some(xs) => xs.clone(),
};
let xs = if self.left_pad_applied {
xs
} else {
self.left_pad_applied = true;
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.conv.config();
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
pad1d(&xs, padding_total, 0, self.pad_mode)?
};
let cfg = self.conv.conv.config();
let stride = cfg.stride;
let dilation = cfg.dilation;
let kernel = (self.kernel_size - 1) * dilation + 1;
let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
let seq_len = xs.seq_len(D::Minus1)?;
let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
if num_frames > 0 {
let offset = num_frames * stride;
self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
let in_l = (num_frames - 1) * stride + kernel;
let xs = xs.narrow(D::Minus1, 0, in_l)?;
// We apply the underlying conv directly rather than through forward so as
// not to apply any padding here.
xs.apply(&self.conv.conv)
} else {
self.state_prev_xs = xs;
Ok(StreamTensor::empty())
}
}
}
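// Editor's note (not part of the diff): a worked example of the streaming step above. With an
// effective kernel of 3 and stride 2, a buffered sequence of length 7 yields
// num_frames = (7 + 2 - 3) / 2 = 3 full frames; the first in_l = (3 - 1) * 2 + 3 = 7 samples
// are convolved and the last 7 - 3 * 2 = 1 sample is kept in `state_prev_xs` for the next step.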
#[derive(Debug, Clone)]
pub struct StreamableConvTranspose1d {
convtr: NormConvTranspose1d,
causal: bool,
state_prev_ys: StreamTensor,
kernel_size: usize,
span: tracing::Span,
}
impl StreamableConvTranspose1d {
#[allow(clippy::too_many_arguments)]
pub fn new(
in_c: usize,
out_c: usize,
k_size: usize,
stride: usize,
groups: usize,
bias: bool,
causal: bool,
norm: Option<Norm>,
vb: VarBuilder,
) -> Result<Self> {
let convtr =
NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
Ok(Self {
convtr,
causal,
kernel_size: k_size,
state_prev_ys: StreamTensor::empty(),
span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
})
}
}
impl Module for StreamableConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let k_size = self.convtr.k_size;
let stride = self.convtr.stride;
let padding_total = k_size.saturating_sub(stride);
let xs = xs.apply(&self.convtr)?;
if self.causal {
// This corresponds to trim_right_ratio = 1.
unpad1d(&xs, 0, padding_total)
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
unpad1d(&xs, padding_left, padding_right)
}
}
}
impl StreamingModule for StreamableConvTranspose1d {
fn reset_state(&mut self) {
self.state_prev_ys.reset()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let xs = match xs.as_option() {
Some(xs) => xs,
None => return Ok(StreamTensor::empty()),
};
let stride = self.convtr.stride;
// We apply the underlying convtr directly rather than through forward so as
// not to apply any padding here.
let ys = self.convtr.forward(xs)?;
let ot = ys.dim(D::Minus1)?;
let ys = match self.state_prev_ys.as_option() {
None => ys,
Some(prev_ys) => {
let pt = prev_ys.dim(D::Minus1)?;
// Remove the bias as it will be applied multiple times.
let prev_ys = match &self.convtr.bs {
None => prev_ys.clone(),
Some(bias) => {
let bias = bias.reshape((1, (), 1))?;
prev_ys.broadcast_sub(&bias)?
}
};
let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
Tensor::cat(&[ys1, ys2], D::Minus1)?
}
};
let invalid_steps = self.kernel_size - stride;
let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
self.state_prev_ys = prev_ys;
Ok(ys)
}
}
#[derive(Debug, Clone)]
pub struct ConvDownsample1d {
conv: StreamableConv1d,
}
impl ConvDownsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let conv = StreamableConv1d::new(
/* in_c */ dim,
/* out_c */ dim,
/* k_size_c */ 2 * stride,
/* stride */ stride,
/* dilation */ 1,
/* groups */ 1, // channel_wise = false
/* bias */ false,
/* causal */ causal,
/* norm */ None,
/* pad_mode */ PadMode::Replicate,
vb,
)?;
Ok(Self { conv })
}
}
impl Module for ConvDownsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.conv)
}
}
impl StreamingModule for ConvDownsample1d {
fn reset_state(&mut self) {
self.conv.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.conv.step(xs)
}
}
#[derive(Debug, Clone)]
pub struct ConvTrUpsample1d {
convtr: StreamableConvTranspose1d,
}
impl ConvTrUpsample1d {
pub fn new(
stride: usize,
dim: usize,
causal: bool,
learnt: bool,
vb: VarBuilder,
) -> Result<Self> {
if !learnt {
candle::bail!("only learnt=true is supported")
}
let convtr = StreamableConvTranspose1d::new(
dim,
dim,
/* k_size */ 2 * stride,
/* stride */ stride,
/* groups */ dim,
/* bias */ false,
/* causal */ causal,
/* norm */ None,
vb,
)?;
Ok(Self { convtr })
}
}
impl Module for ConvTrUpsample1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.convtr)
}
}
impl StreamingModule for ConvTrUpsample1d {
fn reset_state(&mut self) {
self.convtr.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
self.convtr.step(xs)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle::IndexOp;
fn run_conv1d(
k_size: usize,
stride: usize,
dilation: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure that the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConv1d::new(
/* in_c */ 2,
/* out_c */ 3,
/* k_size */ k_size,
/* stride */ stride,
/* dilation */ dilation,
/* groups */ 1,
/* bias */ bias,
/* causal */ true,
/* norm */ None,
/* pad_mode */ PadMode::Constant,
vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
fn run_conv_tr1d(
k_size: usize,
stride: usize,
step_size: usize,
len: usize,
bias: bool,
) -> Result<()> {
// TODO: We should ensure that the seed is constant when running these tests.
let dev = &candle::Device::Cpu;
let vm = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
let conv1d = StreamableConvTranspose1d::new(
/* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
/* stride */ stride, /* groups */ 1, /* bias */ bias,
/* causal */ true, /* norm */ None, vb,
)?;
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
let ys = conv1d.forward(&xs)?;
let mut conv1d = conv1d;
let mut ys_steps = vec![];
for idx in 0..len {
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
let ys = conv1d.step(&xs.into())?;
if let Some(ys) = ys.as_option() {
ys_steps.push(ys.clone())
}
}
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
let diff = (&ys - &ys_steps)?
.abs()?
.flatten_all()?
.max(0)?
.to_vec0::<f32>()?;
if diff > 1e-5 {
println!("{xs}");
println!("{ys}");
println!("{ys_steps}");
candle::bail!("larger diff than expected {diff}")
}
Ok(())
}
#[test]
fn conv1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv1d(1, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 1, 1, step_size, 5, bias)?;
run_conv1d(2, 2, 1, step_size, 6, bias)?;
run_conv1d(3, 2, 1, step_size, 8, bias)?;
run_conv1d(3, 2, 2, step_size, 8, bias)?;
}
}
Ok(())
}
#[test]
fn conv_tr1d() -> Result<()> {
for step_size in [1, 2, 3] {
for bias in [false, true] {
run_conv_tr1d(1, 1, step_size, 5, bias)?;
run_conv_tr1d(2, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 1, step_size, 5, bias)?;
run_conv_tr1d(3, 2, step_size, 5, bias)?;
}
}
Ok(())
}
}

View File

@ -0,0 +1,229 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use super::{conv, quantization, seanet, transformer};
use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
use candle_nn::VarBuilder;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ResampleMethod {
Conv,
Interpolate,
}
#[derive(Debug, Clone)]
pub struct Config {
pub channels: usize,
pub sample_rate: f64,
pub frame_rate: f64,
pub renormalize: bool,
pub resample_method: ResampleMethod,
pub seanet: seanet::Config,
pub transformer: transformer::Config,
pub quantizer_n_q: usize,
pub quantizer_bins: usize,
pub quantizer_dim: usize,
}
impl Config {
// /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
pub fn v0_1(num_codebooks: Option<usize>) -> Self {
let seanet_cfg = seanet::Config {
dimension: 512,
channels: 1,
causal: true,
n_filters: 64,
n_residual_layers: 1,
activation: candle_nn::Activation::Elu(1.),
compress: 2,
dilation_base: 2,
disable_norm_outer_blocks: 0,
final_activation: None,
kernel_size: 7,
residual_kernel_size: 3,
last_kernel_size: 3,
lstm: 0,
norm: conv::Norm::WeightNorm,
pad_mode: conv::PadMode::Constant,
ratios: vec![8, 6, 5, 4],
true_skip: true,
};
let transformer_cfg = transformer::Config {
d_model: seanet_cfg.dimension,
num_heads: 8,
num_layers: 8,
causal: true,
norm_first: true,
bias_ff: false,
bias_attn: false,
layer_scale: Some(0.01),
context: 250,
conv_kernel_size: 5,
use_conv_bias: true,
use_conv_block: false,
cross_attention: false,
max_period: 10000,
gating: None,
norm: super::NormType::LayerNorm,
positional_embedding: transformer::PositionalEmbedding::Rope,
dim_feedforward: 2048,
kv_repeat: 1,
conv_layout: true, // see builders.py
max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
};
Config {
channels: 1,
sample_rate: 24_000.,
frame_rate: 12.5,
renormalize: true,
resample_method: ResampleMethod::Conv,
seanet: seanet_cfg,
transformer: transformer_cfg,
quantizer_n_q: num_codebooks.unwrap_or(16),
quantizer_bins: 2048,
quantizer_dim: 256,
}
}
}
#[derive(Debug, Clone)]
pub struct Encodec {
encoder: seanet::SeaNetEncoder,
decoder: seanet::SeaNetDecoder,
encoder_transformer: transformer::ProjectedTransformer,
decoder_transformer: transformer::ProjectedTransformer,
downsample: conv::ConvDownsample1d,
upsample: conv::ConvTrUpsample1d,
quantizer: quantization::SplitResidualVectorQuantizer,
config: Config,
}
impl Encodec {
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
let dim = cfg.seanet.dimension;
let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
let encoder_transformer = transformer::ProjectedTransformer::new(
dim,
&[dim],
&cfg.transformer,
vb.pp("encoder_transformer"),
)?;
let decoder_transformer = transformer::ProjectedTransformer::new(
dim,
&[dim],
&cfg.transformer,
vb.pp("decoder_transformer"),
)?;
let quantizer = quantization::SplitResidualVectorQuantizer::new(
/* dim */ cfg.quantizer_dim,
/* input_dim */ Some(dim),
/* output_dim */ Some(dim),
/* n_q */ cfg.quantizer_n_q,
/* bins */ cfg.quantizer_bins,
vb.pp("quantizer"),
)?;
let encoder_frame_rate =
cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
// `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
let downsample = conv::ConvDownsample1d::new(
/* stride */ downsample_stride,
/* dim */ dim,
/* causal */ true,
/* learnt */ true,
vb.pp("downsample"),
)?;
let upsample = conv::ConvTrUpsample1d::new(
/* stride */ downsample_stride,
/* dim */ dim,
/* causal */ true,
/* learnt */ true,
vb.pp("upsample"),
)?;
Ok(Self {
encoder,
decoder,
encoder_transformer,
decoder_transformer,
quantizer,
downsample,
upsample,
config: cfg,
})
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
let xs = self.encoder.forward(xs)?;
self.encoder_transformer.reset_state();
let xs = self.encoder_transformer.forward(&xs)?;
let xs = &xs[0];
xs.apply(&self.downsample)
}
pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
let xs = self.encoder.forward(xs)?;
self.encoder_transformer.reset_state();
let xs = self.encoder_transformer.forward(&xs)?;
let xs = &xs[0];
let xs = xs.apply(&self.downsample)?;
let codes = self.quantizer.encode(&xs)?;
Ok(codes)
}
pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let xs = self.encoder.step(xs)?;
let xs = self.encoder_transformer.step(&xs)?;
let xs = self.downsample.step(&xs)?;
match xs.as_option() {
None => Ok(().into()),
Some(xs) => {
let codes = self.quantizer.encode(xs)?;
Ok(codes.into())
}
}
}
pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
let emb = self.quantizer.decode(codes)?;
let emb = emb.apply(&self.upsample)?;
self.decoder_transformer.reset_state();
let outs = self.decoder_transformer.forward(&emb)?;
let out = &outs[0];
self.decoder.forward(out)
}
pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
let emb = match codes.as_option() {
Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
None => StreamTensor::empty(),
};
let emb = self.upsample.step(&emb)?;
let out = self.decoder_transformer.step(&emb)?;
self.decoder.step(&out)
}
pub fn reset_state(&mut self) {
self.encoder.reset_state();
self.encoder_transformer.reset_state();
self.decoder.reset_state();
self.decoder_transformer.reset_state();
self.upsample.reset_state();
}
}
pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
let cfg = Config::v0_1(num_codebooks);
let encodec = Encodec::new(cfg, vb)?;
Ok(encodec)
}
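// Editor's sketch (not part of the diff): a minimal, hypothetical round-trip through the model
// defined above. The function name, the safetensors path and the `pcm` parameter are
// assumptions introduced for illustration; `pcm` is presumably laid out as
// (batch, channels, samples) at 24kHz.
fn mimi_roundtrip_sketch(pcm: &Tensor, dev: &Device) -> Result<Tensor> {
    // The file name here is a placeholder, not a path shipped with the repo.
    let mut model = load("mimi-v0.1.safetensors", None, dev)?;
    model.reset_state();
    // Encode to discrete codes, laid out as (batch, num_codebooks, frames) ...
    let codes = model.encode(pcm)?;
    // ... then decode them back to a waveform.
    model.decode(&codes)
}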

View File

@ -0,0 +1,22 @@
// Adapted from the reference implementation at:
// https://github.com/kyutai-labs/moshi
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
pub use candle;
pub use candle_nn;
pub mod conv;
pub mod encodec;
pub mod quantization;
pub mod seanet;
pub mod transformer;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NormType {
RmsNorm,
LayerNorm,
}
pub use encodec::{load, Config, Encodec as Model};

View File

@ -0,0 +1,404 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{linear, Linear, VarBuilder};
struct CodebookEncode;
impl candle::CustomOp2 for CodebookEncode {
fn name(&self) -> &'static str {
"cb"
}
fn cpu_fwd(
&self,
lhs_storage: &candle::CpuStorage,
lhs_layout: &Layout,
rhs_storage: &candle::CpuStorage,
rhs_layout: &Layout,
) -> Result<(candle::CpuStorage, Shape)> {
use rayon::prelude::*;
let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
if lhs_dim2 != rhs_dim2 {
candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
}
if lhs_dim2 == 0 {
candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
}
let lhs = match lhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
Some((o1, o2)) => {
let slice = lhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let rhs = match rhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
Some((o1, o2)) => {
let slice = rhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let dst = (0..lhs_dim1)
.into_par_iter()
.map(|idx1| {
let mut where_min = 0;
let mut min_dist = f32::INFINITY;
let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
for idx2 in 0..rhs_dim1 {
let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
let mut dist = 0f32;
for (a, b) in lhs.iter().zip(rhs.iter()) {
dist += (a - b) * (a - b)
}
if dist < min_dist {
min_dist = dist;
where_min = idx2;
}
}
where_min as u32
})
.collect();
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (lhs_dim1,).into()))
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct EuclideanCodebook {
initialized: Tensor,
cluster_usage: Tensor,
embedding_sum: Tensor,
embedding: Tensor,
c2: Tensor,
epsilon: f64,
dim: usize,
span_encode: tracing::Span,
span_decode: tracing::Span,
}
impl EuclideanCodebook {
pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
let epsilon = 1e-5;
let initialized = vb.get(1, "initialized")?;
let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
let embedding = {
let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
embedding_sum.broadcast_div(&cluster_usage)?
};
let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
Ok(Self {
initialized,
cluster_usage,
embedding_sum,
embedding,
c2,
epsilon,
dim,
span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
})
}
pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
// TODO: avoid repeating this.
let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
// Manual cdist implementation.
let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
let dists = diff.sqr()?.sum(D::Minus1)?;
let codes = dists.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let dot_prod = xs.matmul(&self.embedding.t()?)?;
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
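// Editor's note (not part of the diff): `encode_slow` relies on the identity
//   ||x - e||^2 = ||x||^2 - 2 * <x, e> + ||e||^2,
// where ||x||^2 is constant across codebook entries, so
//   argmin_e ||x - e||^2 = argmin_e (||e||^2 / 2 - <x, e>),
// which is exactly the `c2.broadcast_sub(&dot_prod)` expression above, with c2 precomputed in
// the constructor as sum(e^2, last_dim) / 2.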
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
codes.reshape(target_shape)
}
pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
let _enter = self.span_decode.enter();
// let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.dim);
let indexes = indexes.flatten_all()?;
let values = self.embedding.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct VectorQuantization {
project_in: Option<Linear>,
project_out: Option<Linear>,
codebook: EuclideanCodebook,
}
impl VectorQuantization {
pub fn new(
dim: usize,
codebook_size: usize,
codebook_dim: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let codebook_dim = codebook_dim.unwrap_or(dim);
let (project_in, project_out) = if codebook_dim == dim {
(None, None)
} else {
let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
(Some(p_in), Some(p_out))
};
let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
Ok(Self {
project_in,
project_out,
codebook,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.t()?.apply(&self.project_in.as_ref())?;
self.codebook.encode_slow(&xs)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let quantized = self.codebook.decode(codes)?;
let quantized = match &self.project_out {
None => quantized,
Some(p) => quantized.apply(p)?,
};
quantized.t()
}
}
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantization {
layers: Vec<VectorQuantization>,
}
impl ResidualVectorQuantization {
pub fn new(
n_q: usize,
dim: usize,
codebook_size: usize,
codebook_dim: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let vb = vb.pp("layers");
let mut layers = Vec::with_capacity(n_q);
for i in 0..n_q {
let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
layers.push(layer)
}
Ok(Self { layers })
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut codes = Vec::with_capacity(self.layers.len());
let mut residual = xs.clone();
for layer in self.layers.iter() {
let indices = layer.encode(&residual)?;
let quantized = layer.decode(&indices)?;
residual = (residual - quantized)?;
codes.push(indices)
}
Tensor::stack(&codes, 0)
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
if self.layers.is_empty() {
candle::bail!("empty layers in ResidualVectorQuantization")
}
if self.layers.len() != xs.dim(0)? {
candle::bail!(
"mismatch between the number of layers {} and the code shape {:?}",
self.layers.len(),
xs.shape()
)
}
let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
for (i, layer) in self.layers.iter().enumerate().skip(1) {
let xs = xs.i(i)?;
quantized = (quantized + layer.decode(&xs))?
}
Ok(quantized)
}
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantizer {
vq: ResidualVectorQuantization,
input_proj: Option<candle_nn::Conv1d>,
output_proj: Option<candle_nn::Conv1d>,
}
impl ResidualVectorQuantizer {
pub fn new(
dim: usize,
input_dim: Option<usize>,
output_dim: Option<usize>,
n_q: usize,
bins: usize,
force_projection: bool,
vb: VarBuilder,
) -> Result<Self> {
let input_dim = input_dim.unwrap_or(dim);
let output_dim = output_dim.unwrap_or(dim);
let input_proj = if input_dim == dim && !force_projection {
None
} else {
let c = candle_nn::conv1d_no_bias(
input_dim,
dim,
1,
Default::default(),
vb.pp("input_proj"),
)?;
Some(c)
};
let output_proj = if output_dim == dim && !force_projection {
None
} else {
let c = candle_nn::conv1d_no_bias(
dim,
output_dim,
1,
Default::default(),
vb.pp("output_proj"),
)?;
Some(c)
};
let vq = ResidualVectorQuantization::new(
n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
)?;
Ok(Self {
vq,
input_proj,
output_proj,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
codes.transpose(0, 1)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
// codes is [B, K, T], with T frames and K the number of codebooks; vq.decode expects [K, B, T].
let codes = codes.transpose(0, 1)?;
let quantized = self.vq.decode(&codes)?;
match &self.output_proj {
None => Ok(quantized),
Some(p) => quantized.apply(p),
}
}
}
// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
// concatenate the indexes.
#[derive(Debug, Clone)]
pub struct SplitResidualVectorQuantizer {
rvq_first: ResidualVectorQuantizer,
rvq_rest: ResidualVectorQuantizer,
n_q: usize,
span_encode: tracing::Span,
span_decode: tracing::Span,
}
impl SplitResidualVectorQuantizer {
pub fn new(
dim: usize,
input_dim: Option<usize>,
output_dim: Option<usize>,
n_q: usize,
bins: usize,
vb: VarBuilder,
) -> Result<Self> {
let rvq_first = ResidualVectorQuantizer::new(
dim,
input_dim,
output_dim,
1,
bins,
true,
vb.pp("semantic_residual_vector_quantizer"),
)?;
let rvq_rest = ResidualVectorQuantizer::new(
dim,
input_dim,
output_dim,
n_q - 1,
bins,
true,
vb.pp("acoustic_residual_vector_quantizer"),
)?;
let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
Ok(Self {
rvq_first,
rvq_rest,
n_q,
span_encode,
span_decode,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span_encode.enter();
let codes = self.rvq_first.encode(xs)?;
if self.n_q > 1 {
// We encode xs again here rather than the residual: the decomposition is not
// hierarchical, rvq_first produces the semantic tokens and rvq_rest produces the
// acoustic tokens (see the usage sketch after this impl block).
let rest_codes = self.rvq_rest.encode(xs)?;
Tensor::cat(&[codes, rest_codes], 1)
} else {
Ok(codes)
}
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
// codes is [B, K, T], with T frames and K the number of codebooks.
let _enter = self.span_decode.enter();
let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
let quantized = if self.n_q > 1 {
(quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
} else {
quantized
};
Ok(quantized)
}
}
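// A minimal usage sketch for this quantizer (editor illustration, not part of the
// upstream file; `embeddings`, the number of codebooks (8) and the codebook size (2048)
// below are assumptions chosen for illustration):
//
//     // `embeddings` comes from the encoder with shape [B, dim, T].
//     let quantizer = SplitResidualVectorQuantizer::new(dim, None, None, 8, 2048, vb)?;
//     let codes = quantizer.encode(&embeddings)?; // [B, 8, T]: codes[:, 0] semantic,
//                                                 // codes[:, 1..] acoustic tokens.
//     let decoded = quantizer.decode(&codes)?;    // [B, dim, T]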


@@ -0,0 +1,465 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
use candle_nn::VarBuilder;
use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
#[derive(Debug, Clone)]
pub struct Config {
pub dimension: usize,
pub channels: usize,
pub causal: bool,
pub n_filters: usize,
pub n_residual_layers: usize,
pub ratios: Vec<usize>,
pub activation: candle_nn::Activation,
pub norm: super::conv::Norm,
pub kernel_size: usize,
pub residual_kernel_size: usize,
pub last_kernel_size: usize,
pub dilation_base: usize,
pub pad_mode: super::conv::PadMode,
pub true_skip: bool,
pub compress: usize,
pub lstm: usize,
pub disable_norm_outer_blocks: usize,
pub final_activation: Option<candle_nn::Activation>,
}
#[derive(Debug, Clone)]
pub struct SeaNetResnetBlock {
block: Vec<StreamableConv1d>,
shortcut: Option<StreamableConv1d>,
activation: candle_nn::Activation,
skip_op: candle::StreamingBinOp,
span: tracing::Span,
}
impl SeaNetResnetBlock {
#[allow(clippy::too_many_arguments)]
pub fn new(
dim: usize,
k_sizes_and_dilations: &[(usize, usize)],
activation: candle_nn::Activation,
norm: Option<super::conv::Norm>,
causal: bool,
pad_mode: super::conv::PadMode,
compress: usize,
true_skip: bool,
vb: VarBuilder,
) -> Result<Self> {
let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
let hidden = dim / compress;
let vb_b = vb.pp("block");
for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
let in_c = if i == 0 { dim } else { hidden };
let out_c = if i == k_sizes_and_dilations.len() - 1 {
dim
} else {
hidden
};
let c = StreamableConv1d::new(
in_c,
out_c,
/* k_size */ *k_size,
/* stride */ 1,
/* dilation */ *dilation,
/* groups */ 1,
/* bias */ true,
/* causal */ causal,
/* norm */ norm,
/* pad_mode */ pad_mode,
vb_b.pp(2 * i + 1),
)?;
block.push(c)
}
let shortcut = if true_skip {
None
} else {
let c = StreamableConv1d::new(
dim,
dim,
/* k_size */ 1,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ causal,
/* norm */ norm,
/* pad_mode */ pad_mode,
vb.pp("shortcut"),
)?;
Some(c)
};
Ok(Self {
block,
shortcut,
activation,
skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
})
}
}
impl Module for SeaNetResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut ys = xs.clone();
for block in self.block.iter() {
ys = ys.apply(&self.activation)?.apply(block)?;
}
match self.shortcut.as_ref() {
None => ys + xs,
Some(shortcut) => ys + xs.apply(shortcut),
}
}
}
impl StreamingModule for SeaNetResnetBlock {
fn reset_state(&mut self) {
for block in self.block.iter_mut() {
block.reset_state()
}
if let Some(shortcut) = self.shortcut.as_mut() {
shortcut.reset_state()
}
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut ys = xs.clone();
for block in self.block.iter_mut() {
ys = block.step(&ys.apply(&self.activation)?)?;
}
match self.shortcut.as_ref() {
None => self.skip_op.step(&ys, xs),
Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
}
}
}
#[derive(Debug, Clone)]
struct EncoderLayer {
residuals: Vec<SeaNetResnetBlock>,
downsample: StreamableConv1d,
}
#[derive(Debug, Clone)]
pub struct SeaNetEncoder {
init_conv1d: StreamableConv1d,
activation: candle_nn::Activation,
layers: Vec<EncoderLayer>,
final_conv1d: StreamableConv1d,
span: tracing::Span,
}
impl SeaNetEncoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.lstm > 0 {
candle::bail!("seanet lstm is not supported")
}
let n_blocks = 2 + cfg.ratios.len();
let mut mult = 1usize;
let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
None
} else {
Some(cfg.norm)
};
let mut layer_idx = 0;
let vb = vb.pp("layers");
let init_conv1d = StreamableConv1d::new(
cfg.channels,
mult * cfg.n_filters,
cfg.kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ init_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx),
)?;
layer_idx += 1;
let mut layers = Vec::with_capacity(cfg.ratios.len());
for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
None
} else {
Some(cfg.norm)
};
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
for j in 0..cfg.n_residual_layers {
let resnet_block = SeaNetResnetBlock::new(
mult * cfg.n_filters,
&[
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
(1, 1),
],
cfg.activation,
norm,
cfg.causal,
cfg.pad_mode,
cfg.compress,
cfg.true_skip,
vb.pp(layer_idx),
)?;
residuals.push(resnet_block);
layer_idx += 1;
}
let downsample = StreamableConv1d::new(
mult * cfg.n_filters,
mult * cfg.n_filters * 2,
/* k_size */ ratio * 2,
/* stride */ ratio,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ true,
/* norm */ norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
layer_idx += 2;
let layer = EncoderLayer {
downsample,
residuals,
};
layers.push(layer);
mult *= 2
}
let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
None
} else {
Some(cfg.norm)
};
let final_conv1d = StreamableConv1d::new(
mult * cfg.n_filters,
cfg.dimension,
cfg.last_kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ final_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
Ok(Self {
init_conv1d,
activation: cfg.activation,
layers,
final_conv1d,
span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
})
}
}
impl Module for SeaNetEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.apply(&self.init_conv1d)?;
for layer in self.layers.iter() {
for residual in layer.residuals.iter() {
xs = xs.apply(residual)?
}
xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
}
xs.apply(&self.activation)?.apply(&self.final_conv1d)
}
}
impl StreamingModule for SeaNetEncoder {
fn reset_state(&mut self) {
self.init_conv1d.reset_state();
self.layers.iter_mut().for_each(|v| {
v.residuals.iter_mut().for_each(|v| v.reset_state());
v.downsample.reset_state()
});
self.final_conv1d.reset_state();
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut xs = self.init_conv1d.step(xs)?;
for layer in self.layers.iter_mut() {
for residual in layer.residuals.iter_mut() {
xs = residual.step(&xs)?;
}
xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
}
self.final_conv1d.step(&xs.apply(&self.activation)?)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
upsample: StreamableConvTranspose1d,
residuals: Vec<SeaNetResnetBlock>,
}
#[derive(Debug, Clone)]
pub struct SeaNetDecoder {
init_conv1d: StreamableConv1d,
activation: candle_nn::Activation,
layers: Vec<DecoderLayer>,
final_conv1d: StreamableConv1d,
final_activation: Option<candle_nn::Activation>,
span: tracing::Span,
}
impl SeaNetDecoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.lstm > 0 {
candle::bail!("seanet lstm is not supported")
}
let n_blocks = 2 + cfg.ratios.len();
let mut mult = 1 << cfg.ratios.len();
let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
None
} else {
Some(cfg.norm)
};
let mut layer_idx = 0;
let vb = vb.pp("layers");
let init_conv1d = StreamableConv1d::new(
cfg.dimension,
mult * cfg.n_filters,
cfg.kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ init_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx),
)?;
layer_idx += 1;
let mut layers = Vec::with_capacity(cfg.ratios.len());
for (i, &ratio) in cfg.ratios.iter().enumerate() {
let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
None
} else {
Some(cfg.norm)
};
let upsample = StreamableConvTranspose1d::new(
mult * cfg.n_filters,
mult * cfg.n_filters / 2,
/* k_size */ ratio * 2,
/* stride */ ratio,
/* groups */ 1,
/* bias */ true,
/* causal */ true,
/* norm */ norm,
vb.pp(layer_idx + 1),
)?;
layer_idx += 2;
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
for j in 0..cfg.n_residual_layers {
let resnet_block = SeaNetResnetBlock::new(
mult * cfg.n_filters / 2,
&[
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
(1, 1),
],
cfg.activation,
norm,
cfg.causal,
cfg.pad_mode,
cfg.compress,
cfg.true_skip,
vb.pp(layer_idx),
)?;
residuals.push(resnet_block);
layer_idx += 1;
}
let layer = DecoderLayer {
upsample,
residuals,
};
layers.push(layer);
mult /= 2
}
let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
None
} else {
Some(cfg.norm)
};
let final_conv1d = StreamableConv1d::new(
cfg.n_filters,
cfg.channels,
cfg.last_kernel_size,
/* stride */ 1,
/* dilation */ 1,
/* groups */ 1,
/* bias */ true,
/* causal */ cfg.causal,
/* norm */ final_norm,
/* pad_mode */ cfg.pad_mode,
vb.pp(layer_idx + 1),
)?;
Ok(Self {
init_conv1d,
activation: cfg.activation,
layers,
final_conv1d,
final_activation: cfg.final_activation,
span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
})
}
}
impl Module for SeaNetDecoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.apply(&self.init_conv1d)?;
for layer in self.layers.iter() {
xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
for residual in layer.residuals.iter() {
xs = xs.apply(residual)?
}
}
let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
let xs = match self.final_activation.as_ref() {
None => xs,
Some(act) => xs.apply(act)?,
};
Ok(xs)
}
}
impl StreamingModule for SeaNetDecoder {
fn reset_state(&mut self) {
self.init_conv1d.reset_state();
self.layers.iter_mut().for_each(|v| {
v.residuals.iter_mut().for_each(|v| v.reset_state());
v.upsample.reset_state()
});
self.final_conv1d.reset_state();
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let _enter = self.span.enter();
let mut xs = self.init_conv1d.step(xs)?;
for layer in self.layers.iter_mut() {
xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
for residual in layer.residuals.iter_mut() {
xs = residual.step(&xs)?;
}
}
let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
let xs = match self.final_activation.as_ref() {
None => xs,
Some(act) => xs.apply(act)?,
};
Ok(xs)
}
}


@@ -0,0 +1,777 @@
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
use candle_nn::{linear_no_bias, Linear, VarBuilder};
use std::sync::Arc;
fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
if bias {
candle_nn::linear(in_d, out_d, vb)
} else {
linear_no_bias(in_d, out_d, vb)
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PositionalEmbedding {
Rope,
Sin,
None,
}
#[derive(Debug, Clone)]
pub struct Config {
pub d_model: usize,
pub num_heads: usize,
pub num_layers: usize,
pub causal: bool,
pub norm_first: bool,
pub bias_ff: bool,
pub bias_attn: bool,
pub layer_scale: Option<f64>,
pub positional_embedding: PositionalEmbedding,
pub use_conv_block: bool,
pub cross_attention: bool,
pub conv_kernel_size: usize,
pub use_conv_bias: bool,
pub gating: Option<candle_nn::Activation>,
pub norm: super::NormType,
pub context: usize,
pub max_period: usize,
pub max_seq_len: usize,
pub kv_repeat: usize,
pub dim_feedforward: usize,
pub conv_layout: bool,
}
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
span: tracing::Span,
}
impl RotaryEmbedding {
pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
span: tracing::span!(tracing::Level::TRACE, "rot"),
})
}
pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
let qk_dtype = qk.dtype();
let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
}
}
#[derive(Debug, Clone)]
pub struct LayerScale {
scale: Tensor,
}
impl LayerScale {
pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
let scale = vb.get(d_model, "scale")?;
Ok(Self { scale })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.scale)
}
}
#[derive(Debug, Clone)]
pub struct StreamingMultiheadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
kv_repeat: usize,
num_heads: usize,
context: usize,
neg_inf: Tensor,
rope: Option<Arc<RotaryEmbedding>>,
kv_cache: candle_nn::kv_cache::RotatingKvCache,
pos: usize,
use_flash_attn: bool,
span: tracing::Span,
}
impl StreamingMultiheadAttention {
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
rope: rope.clone(),
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
context: cfg.context,
neg_inf,
kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
pos: 0,
use_flash_attn: false,
span: tracing::span!(tracing::Level::TRACE, "mha"),
})
}
pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
if self.kv_repeat != 1 {
candle::bail!("only kv-repeat = 1 is supported")
}
let (b, t, hd) = xs.dims3()?;
let head_dim = hd / self.num_heads;
let q = xs
.apply(&self.q_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
let k = xs
.apply(&self.k_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
let v = xs
.apply(&self.v_proj)?
.reshape((b, t, self.num_heads, head_dim))?;
// qk_layer_norm = None
// kv_repeat = 1, otherwise we would need repeat_kv
let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
if let Some(rope) = &self.rope {
q = rope.apply_rotary_emb(&q, self.pos)?;
k = rope.apply_rotary_emb(&k, self.pos)?;
}
let (k, v) = {
self.pos += k.dim(2)?;
self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
};
// The KV cache currently keeps all the data; trim the part that comes from the
// cache down to at most `context` entries so that it stays coherent with the mask
// shape we provide.
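// For example (illustrative numbers only): with context = 250, a single-step query
// (t = 1) and k_len = 300 cached positions, k_target_len = 1 + min(250, 300 - 1) = 251,
// so only the last 251 keys/values are kept for the attention below.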
let k_len = k.dim(2)?;
let k_target_len = t + usize::min(self.context, k_len - t);
let (k, v) = if k_target_len < k_len {
let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
(k, v)
} else {
(k.clone(), v.clone())
};
let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
let softmax_scale = 1f32 / (head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
} else {
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
}
};
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
ws.matmul(&v)? // b,h,t,d
};
let xs = xs
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
.apply(&self.out_proj)?;
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
self.kv_cache.reset()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
self.kv_cache = kv_cache
}
}
#[derive(Debug, Clone)]
pub struct StreamingMultiheadCrossAttention {
in_proj_q: Linear,
in_proj_k: Linear,
in_proj_v: Linear,
out_proj: Linear,
kv_repeat: usize,
num_heads: usize,
neg_inf: Tensor,
span: tracing::Span,
}
impl StreamingMultiheadCrossAttention {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_dim = cfg.d_model;
let num_kv = cfg.num_heads / cfg.kv_repeat;
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
let out_dim = embed_dim + 2 * kv_dim;
let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
let b = vb.get(out_dim, "in_proj_bias")?;
let q = b.narrow(0, 0, embed_dim)?;
let k = b.narrow(0, embed_dim, kv_dim)?;
let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
(Some(q), Some(k), Some(v))
} else {
(None, None, None)
};
let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
Ok(Self {
in_proj_q,
in_proj_k,
in_proj_v,
out_proj,
kv_repeat: cfg.kv_repeat,
num_heads: cfg.num_heads,
neg_inf,
span: tracing::span!(tracing::Level::TRACE, "mhca"),
})
}
pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
if self.kv_repeat != 1 {
candle::bail!("only kv-repeat = 1 is supported")
}
let (b, t, hd) = xs.dims3()?;
let head_dim = hd / self.num_heads;
// time_dim = 1, layout: b,t,h,d
let q = xs.apply(&self.in_proj_q)?;
let k = ca_src.apply(&self.in_proj_k)?;
let v = ca_src.apply(&self.in_proj_v)?;
let (ca_b, ca_t, ca_dim) = k.dims3()?;
let q = q.reshape((b, t, self.num_heads, head_dim))?;
let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
// qk_layer_norm = None
// kv_repeat = 1, otherwise we would need repeat_kv
let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
}
};
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
let xs = ws.matmul(&v)?; // b,h,t,d
let xs = xs
.transpose(1, 2)? // b,t,h,d
.reshape((b, t, hd))?
.apply(&self.out_proj)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub enum Mlp {
NoGating {
span1: tracing::Span,
linear1: Linear,
span2: tracing::Span,
linear2: Linear,
span: tracing::Span,
},
Gating {
linear_in: Linear,
linear_out: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
},
}
impl Mlp {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_model = cfg.d_model;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
match cfg.gating {
None => {
let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
Ok(Self::NoGating {
linear1,
linear2,
span,
span1,
span2,
})
}
Some(activation) => {
let vb = vb.pp("gating");
let hidden = if cfg.dim_feedforward == 4 * d_model {
11 * d_model / 4
} else {
2 * cfg.dim_feedforward / 3
};
// TODO: Maybe use bias_ff here?
let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
Ok(Self::Gating {
linear_in,
linear_out,
activation,
span,
})
}
}
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::NoGating {
linear1,
linear2,
span,
span1,
span2,
} => {
let _enter = span.enter();
let xs = {
let _enter = span1.enter();
xs.apply(linear1)?
};
let xs = xs.gelu_erf()?;
{
let _enter = span2.enter();
xs.apply(linear2)
}
}
Self::Gating {
linear_in,
linear_out,
activation,
span,
} => {
let _enter = span.enter();
let xs = xs.apply(linear_in)?;
let (b, t, _) = xs.dims3()?;
let xs = xs.reshape((b, t, 2, ()))?;
let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
xs.apply(linear_out)
}
}
}
}
#[derive(Debug, Clone)]
pub struct RmsNorm {
pub(crate) alpha: Tensor,
pub(crate) eps: f32,
}
impl RmsNorm {
pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
Ok(Self { alpha, eps })
}
}
impl Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
}
}
#[derive(Debug, Clone)]
pub enum Norm {
LayerNorm(candle_nn::LayerNorm),
RmsNorm(RmsNorm),
}
impl Norm {
pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = match cfg.norm {
super::NormType::LayerNorm => {
let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
Self::LayerNorm(norm)
}
super::NormType::RmsNorm => {
let norm = RmsNorm::new(d_model, 1e-8, vb)?;
Self::RmsNorm(norm)
}
};
Ok(norm)
}
}
impl Module for Norm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::LayerNorm(m) => m.forward(xs),
Self::RmsNorm(m) => m.forward(xs),
}
}
}
#[derive(Debug, Clone)]
pub struct StreamingTransformerLayer {
self_attn: StreamingMultiheadAttention,
mlp: Mlp,
norm1: Norm,
norm2: Norm,
layer_scale_1: Option<LayerScale>,
layer_scale_2: Option<LayerScale>,
cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
norm_first: bool,
span: tracing::Span,
}
impl StreamingTransformerLayer {
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
if cfg.use_conv_block {
candle::bail!("conv-block is not supported")
}
let d_model = cfg.d_model;
let mlp = Mlp::new(cfg, vb.clone())?;
let (norm1, norm2) = match cfg.norm {
super::NormType::LayerNorm => {
let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
let norm2 =
candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
(Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
}
super::NormType::RmsNorm => {
let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
(Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
}
};
let layer_scale_1 = match cfg.layer_scale {
None => None,
Some(ls) => {
let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
Some(ls)
}
};
let layer_scale_2 = match cfg.layer_scale {
None => None,
Some(ls) => {
let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
Some(ls)
}
};
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
let cross_attn = if cfg.cross_attention {
let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
Some((norm_cross, cross_attn))
} else {
None
};
Ok(Self {
self_attn,
mlp,
norm1,
norm2,
layer_scale_1,
layer_scale_2,
cross_attn,
norm_first: cfg.norm_first,
span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
})
}
pub fn forward(
&mut self,
xs: &Tensor,
ca_src: Option<&Tensor>,
mask: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
if !self.norm_first {
candle::bail!("only norm_first = true is supported")
}
let norm1 = xs.apply(&self.norm1)?;
let xs = (xs
+ self
.self_attn
.forward(&norm1, mask)?
.apply(&self.layer_scale_1.as_ref())?)?;
let xs = match (&self.cross_attn, ca_src) {
(Some((norm_cross, cross_attn)), Some(ca_src)) => {
let residual = &xs;
let xs = xs.apply(norm_cross)?;
(residual + cross_attn.forward(&xs, ca_src, None)?)?
}
_ => xs,
};
let xs = (&xs
+ xs.apply(&self.norm2)?
.apply(&self.mlp)?
.apply(&self.layer_scale_2.as_ref()))?;
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
self.self_attn.set_kv_cache(kv_cache)
}
}
#[derive(Debug, Clone)]
pub struct StreamingTransformer {
layers: Vec<StreamingTransformerLayer>,
positional_embedding: PositionalEmbedding,
max_period: usize,
}
impl StreamingTransformer {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_l = vb.pp("layers");
let rope = match cfg.positional_embedding {
PositionalEmbedding::Rope => {
let rope = RotaryEmbedding::new(
cfg.d_model / cfg.num_heads,
cfg.max_seq_len,
cfg.max_period as f32,
vb.device(),
)?;
Some(Arc::new(rope))
}
PositionalEmbedding::Sin | PositionalEmbedding::None => None,
};
let mut layers = Vec::with_capacity(cfg.num_layers);
for layer_idx in 0..cfg.num_layers {
let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self {
layers,
positional_embedding: cfg.positional_embedding,
max_period: cfg.max_period,
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
self.forward_ca(xs, None)
}
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
let (_b, t, c) = xs.dims3()?;
let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
let mask = self.layers[0]
.self_attn
.kv_cache
.attn_mask(t, xs.device())?;
let mut xs = match self.positional_embedding {
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
PositionalEmbedding::Sin => {
let dev = xs.device();
let theta = self.max_period as f32;
let half_dim = c / 2;
let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
.unsqueeze(1)?
.to_dtype(DType::F32)?;
let inv_freq: Vec<_> = (0..half_dim)
.map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let freqs = positions.broadcast_mul(&inv_freq)?;
let pos_emb =
Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
xs.broadcast_add(&pos_emb)?
}
};
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, ca_src, mask.as_ref())?;
}
Ok(xs)
}
pub fn copy_state(&mut self, from: &Self) -> Result<()> {
if self.layers.len() != from.layers.len() {
candle::bail!("cannot copy kv-caches as the transformers have different depths")
}
self.layers
.iter_mut()
.zip(from.layers.iter())
.for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
Ok(())
}
}
impl StreamingModule for StreamingTransformer {
fn reset_state(&mut self) {
self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
match xs.as_option() {
None => Ok(StreamTensor::empty()),
Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
}
}
}
#[derive(Debug, Clone)]
pub struct ProjectedTransformer {
transformer: StreamingTransformer,
input_proj: Option<Linear>,
output_projs: Vec<Option<Linear>>,
conv_layout: bool,
span: tracing::Span,
}
impl ProjectedTransformer {
pub fn new(
input_dim: usize,
output_dims: &[usize],
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let transformer = StreamingTransformer::new(cfg, vb.clone())?;
let input_proj = if input_dim == cfg.d_model {
None
} else {
let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
Some(l)
};
let mut output_projs = Vec::with_capacity(output_dims.len());
let vb_o = vb.pp("output_projs");
for (i, &output_dim) in output_dims.iter().enumerate() {
let output_proj = if output_dim == cfg.d_model {
None
} else {
let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
Some(l)
};
output_projs.push(output_proj)
}
Ok(Self {
transformer,
input_proj,
output_projs,
conv_layout: cfg.conv_layout,
span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
})
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
let _enter = self.span.enter();
let xs = if self.conv_layout {
xs.transpose(1, 2)?
} else {
xs.clone()
};
let xs = xs.apply(&self.input_proj.as_ref())?;
let xs = self.transformer.forward(&xs)?;
let mut ys = Vec::with_capacity(self.output_projs.len());
for output_proj in self.output_projs.iter() {
let ys_ = xs.apply(&output_proj.as_ref())?;
let ys_ = if self.conv_layout {
ys_.transpose(1, 2)?
} else {
ys_
};
ys.push(ys_)
}
Ok(ys)
}
}
impl StreamingModule for ProjectedTransformer {
fn reset_state(&mut self) {
self.transformer.reset_state()
}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
let xs = xs.apply(&|x: &Tensor| {
if self.conv_layout {
x.transpose(1, 2)
} else {
Ok(x.clone())
}
})?;
let xs = xs.apply(&self.input_proj.as_ref())?;
let xs = self.transformer.step(&xs)?;
let ys = xs.apply(&self.output_projs[0].as_ref())?;
ys.apply(&|y: &Tensor| {
if self.conv_layout {
y.transpose(1, 2)
} else {
Ok(y.clone())
}
})
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}


@@ -0,0 +1,294 @@
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
pub struct ModulateIntermediates {
gate_msa: Tensor,
shift_mlp: Tensor,
scale_mlp: Tensor,
gate_mlp: Tensor,
}
pub struct DiTBlock {
norm1: LayerNormNoAffine,
attn: AttnProjections,
norm2: LayerNormNoAffine,
mlp: Mlp,
ada_ln_modulation: nn::Sequential,
}
pub struct LayerNormNoAffine {
eps: f64,
}
impl LayerNormNoAffine {
pub fn new(eps: f64) -> Self {
Self { eps }
}
}
impl Module for LayerNormNoAffine {
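// Layer norm with a unit weight and no bias, i.e. pure normalization; the weight tensor
// is rebuilt from `ones_like(x)` on every forward call.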
fn forward(&self, x: &Tensor) -> Result<Tensor> {
nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
}
}
impl DiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
// {'hidden_size': 1536, 'num_heads': 24}
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
let mlp_ratio = 4;
let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
let n_mods = 6;
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
hidden_size,
n_mods * hidden_size,
vb.pp("adaLN_modulation.1"),
)?);
Ok(Self {
norm1,
attn,
norm2,
mlp,
ada_ln_modulation,
})
}
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
let modulation = self.ada_ln_modulation.forward(c)?;
let chunks = modulation.chunk(6, D::Minus1)?;
let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
chunks[0].clone(),
chunks[1].clone(),
chunks[2].clone(),
chunks[3].clone(),
chunks[4].clone(),
chunks[5].clone(),
);
let norm_x = self.norm1.forward(x)?;
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
let qkv = self.attn.pre_attention(&modulated_x)?;
Ok((
qkv,
ModulateIntermediates {
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
},
))
}
pub fn post_attention(
&self,
attn: &Tensor,
x: &Tensor,
mod_interm: &ModulateIntermediates,
) -> Result<Tensor> {
let attn_out = self.attn.post_attention(attn)?;
let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
let norm_x = self.norm2.forward(&x)?;
let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
let mlp_out = self.mlp.forward(&modulated_x)?;
let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
Ok(x)
}
}
pub struct QkvOnlyDiTBlock {
norm1: LayerNormNoAffine,
attn: QkvOnlyAttnProjections,
ada_ln_modulation: nn::Sequential,
}
impl QkvOnlyDiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let n_mods = 2;
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
hidden_size,
n_mods * hidden_size,
vb.pp("adaLN_modulation.1"),
)?);
Ok(Self {
norm1,
attn,
ada_ln_modulation,
})
}
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
let modulation = self.ada_ln_modulation.forward(c)?;
let chunks = modulation.chunk(2, D::Minus1)?;
let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
let norm_x = self.norm1.forward(x)?;
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
self.attn.pre_attention(&modulated_x)
}
}
pub struct FinalLayer {
norm_final: LayerNormNoAffine,
linear: nn::Linear,
ada_ln_modulation: nn::Sequential,
}
impl FinalLayer {
pub fn new(
hidden_size: usize,
patch_size: usize,
out_channels: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let norm_final = LayerNormNoAffine::new(1e-6);
let linear = nn::linear(
hidden_size,
patch_size * patch_size * out_channels,
vb.pp("linear"),
)?;
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
hidden_size,
2 * hidden_size,
vb.pp("adaLN_modulation.1"),
)?);
Ok(Self {
norm_final,
linear,
ada_ln_modulation,
})
}
pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
let modulation = self.ada_ln_modulation.forward(c)?;
let chunks = modulation.chunk(2, D::Minus1)?;
let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
let norm_x = self.norm_final.forward(x)?;
let modulated_x = modulate(&norm_x, &shift, &scale)?;
let output = self.linear.forward(&modulated_x)?;
Ok(output)
}
}
fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
let shift = shift.unsqueeze(1)?;
let scale = scale.unsqueeze(1)?;
let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
}
pub struct JointBlock {
x_block: DiTBlock,
context_block: DiTBlock,
num_heads: usize,
}
impl JointBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
Ok(Self {
x_block,
context_block,
num_heads,
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
let context_out =
self.context_block
.post_attention(&context_attn, context, &context_interm)?;
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
Ok((context_out, x_out))
}
}
pub struct ContextQkvOnlyJointBlock {
x_block: DiTBlock,
context_block: QkvOnlyDiTBlock,
num_heads: usize,
}
impl ContextQkvOnlyJointBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
Ok(Self {
x_block,
context_block,
num_heads,
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
let context_qkv = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
Ok(x_out)
}
}
// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn
// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim)
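// Shape walk-through of the layout juggling below (generic shapes, given as an
// illustration): q, k and v come in as (batch, seqlen, nheads, headdim); after the
// transpose and flatten they become (batch * nheads, seqlen, headdim), the attention
// weights are (batch * nheads, seqlen, seqlen), and the output is reshaped back to
// (batch, nheads, seqlen, headdim) before the final transpose to
// (batch, seqlen, nheads, headdim).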
fn flash_compatible_attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
) -> Result<Tensor> {
let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
let rank = q_dims_for_matmul.len();
let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
}
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
let qkv = Qkv {
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
};
let (batch_size, seqlen, _) = qkv.q.dims3()?;
let qkv = Qkv {
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
v: qkv.v,
};
let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
let attn = attn.reshape((batch_size, seqlen, ()))?;
let context_qkv_seqlen = context_qkv.q.dim(1)?;
let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
Ok((context_attn, x_attn))
}


@@ -0,0 +1,197 @@
use candle::{bail, DType, Module, Result, Tensor};
use candle_nn as nn;
pub struct PatchEmbedder {
proj: nn::Conv2d,
}
impl PatchEmbedder {
pub fn new(
patch_size: usize,
in_channels: usize,
embed_dim: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let proj = nn::conv2d(
in_channels,
embed_dim,
patch_size,
nn::Conv2dConfig {
stride: patch_size,
..Default::default()
},
vb.pp("proj"),
)?;
Ok(Self { proj })
}
}
impl Module for PatchEmbedder {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.proj.forward(x)?;
// flatten spatial dim and transpose to channels last
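// i.e. (B, C, h, w) -> (B, C, h * w) -> (B, h * w, C), with h = H / patch_size and
// w = W / patch_size after the strided convolution above.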
let (b, c, h, w) = x.dims4()?;
x.reshape((b, c, h * w))?.transpose(1, 2)
}
}
pub struct Unpatchifier {
patch_size: usize,
out_channels: usize,
}
impl Unpatchifier {
pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
Ok(Self {
patch_size,
out_channels,
})
}
pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
let h = (h + 1) / self.patch_size;
let w = (w + 1) / self.patch_size;
let x = x.reshape((
x.dim(0)?,
h,
w,
self.patch_size,
self.patch_size,
self.out_channels,
))?;
let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
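// i.e. (N, h, w, p, q, C) -> (N, C, h, p, w, q), so the reshape below stitches the
// patches back into an image of shape (N, C, h * p, w * q).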
x.reshape((
x.dim(0)?,
self.out_channels,
self.patch_size * h,
self.patch_size * w,
))
}
}
pub struct PositionEmbedder {
pos_embed: Tensor,
patch_size: usize,
pos_embed_max_size: usize,
}
impl PositionEmbedder {
pub fn new(
hidden_size: usize,
patch_size: usize,
pos_embed_max_size: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let pos_embed = vb.get(
(1, pos_embed_max_size * pos_embed_max_size, hidden_size),
"pos_embed",
)?;
Ok(Self {
pos_embed,
patch_size,
pos_embed_max_size,
})
}
pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
let h = (h + 1) / self.patch_size;
let w = (w + 1) / self.patch_size;
if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
bail!("Input size is too large for the position embedding")
}
let top = (self.pos_embed_max_size - h) / 2;
let left = (self.pos_embed_max_size - w) / 2;
let pos_embed =
self.pos_embed
.reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
pos_embed.reshape((1, h * w, ()))
}
}
pub struct TimestepEmbedder {
mlp: nn::Sequential,
frequency_embedding_size: usize,
}
impl TimestepEmbedder {
pub fn new(
hidden_size: usize,
frequency_embedding_size: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let mlp = nn::seq()
.add(nn::linear(
frequency_embedding_size,
hidden_size,
vb.pp("mlp.0"),
)?)
.add(nn::Activation::Silu)
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
Ok(Self {
mlp,
frequency_embedding_size,
})
}
fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
if dim % 2 != 0 {
bail!("Embedding dimension must be even")
}
if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
bail!("Input tensor must be floating point")
}
let half = dim / 2;
let freqs = Tensor::arange(0f32, half as f32, t.device())?
.to_dtype(candle::DType::F32)?
.mul(&Tensor::full(
(-f64::ln(max_period) / half as f64) as f32,
half,
t.device(),
)?)?
.exp()?;
let args = t
.unsqueeze(1)?
.to_dtype(candle::DType::F32)?
.matmul(&freqs.unsqueeze(0)?)?;
let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
embedding.to_dtype(candle::DType::F16)
}
}
impl Module for TimestepEmbedder {
fn forward(&self, t: &Tensor) -> Result<Tensor> {
let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
self.mlp.forward(&t_freq)
}
}
pub struct VectorEmbedder {
mlp: nn::Sequential,
}
impl VectorEmbedder {
pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
let mlp = nn::seq()
.add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
.add(nn::Activation::Silu)
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
Ok(Self { mlp })
}
}
impl Module for VectorEmbedder {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.mlp.forward(x)
}
}


@@ -0,0 +1,4 @@
pub mod blocks;
pub mod embedding;
pub mod model;
pub mod projections;


@@ -0,0 +1,173 @@
// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
// This follows the implementation of the MMDiT model in the ComfyUI repository.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
#[derive(Debug, Clone)]
pub struct Config {
pub patch_size: usize,
pub in_channels: usize,
pub out_channels: usize,
pub depth: usize,
pub head_size: usize,
pub adm_in_channels: usize,
pub pos_embed_max_size: usize,
pub context_embed_size: usize,
pub frequency_embedding_size: usize,
}
impl Config {
pub fn sd3() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 24,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 192,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
}
pub struct MMDiT {
core: MMDiTCore,
patch_embedder: PatchEmbedder,
pos_embedder: PositionEmbedder,
timestep_embedder: TimestepEmbedder,
vector_embedder: VectorEmbedder,
context_embedder: nn::Linear,
unpatchifier: Unpatchifier,
}
impl MMDiT {
pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
let hidden_size = cfg.head_size * cfg.depth;
let core = MMDiTCore::new(
cfg.depth,
hidden_size,
cfg.depth,
cfg.patch_size,
cfg.out_channels,
vb.clone(),
)?;
let patch_embedder = PatchEmbedder::new(
cfg.patch_size,
cfg.in_channels,
hidden_size,
vb.pp("x_embedder"),
)?;
let pos_embedder = PositionEmbedder::new(
hidden_size,
cfg.patch_size,
cfg.pos_embed_max_size,
vb.clone(),
)?;
let timestep_embedder = TimestepEmbedder::new(
hidden_size,
cfg.frequency_embedding_size,
vb.pp("t_embedder"),
)?;
let vector_embedder =
VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
let context_embedder = nn::linear(
cfg.context_embed_size,
hidden_size,
vb.pp("context_embedder"),
)?;
let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
Ok(Self {
core,
patch_embedder,
pos_embedder,
timestep_embedder,
vector_embedder,
context_embedder,
unpatchifier,
})
}
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// t: (N,) tensor of diffusion timesteps
// y: (N,) tensor of class labels
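// Note that in this implementation `y` goes through the VectorEmbedder (a linear layer
// applied on the last dimension), so in practice it is a conditioning embedding of shape
// (N, adm_in_channels) rather than class-label indices. An illustrative call for the
// sd3() config (the latent size below is an arbitrary assumption):
//
//     // x: (1, 16, 128, 128), t: (1,), y: (1, 2048), context: (1, seq_len, 4096)
//     let out = mmdit.forward(&x, &t, &y, &context)?; // (1, 16, 128, 128)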
let h = x.dim(D::Minus2)?;
let w = x.dim(D::Minus1)?;
let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
let x = self
.patch_embedder
.forward(x)?
.broadcast_add(&cropped_pos_embed)?;
let c = self.timestep_embedder.forward(t)?;
let y = self.vector_embedder.forward(y)?;
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
let x = self.core.forward(&context, &x, &c)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
}
pub struct MMDiTCore {
joint_blocks: Vec<JointBlock>,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
final_layer: FinalLayer,
}
impl MMDiTCore {
pub fn new(
depth: usize,
hidden_size: usize,
num_heads: usize,
patch_size: usize,
out_channels: usize,
vb: nn::VarBuilder,
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
joint_blocks.push(JointBlock::new(
hidden_size,
num_heads,
vb.pp(format!("joint_blocks.{}", i)),
)?);
}
Ok(Self {
joint_blocks,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(
hidden_size,
patch_size,
out_channels,
vb.pp("final_layer"),
)?,
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
for joint_block in &self.joint_blocks {
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
self.final_layer.forward(&x, c)
}
}


@@ -0,0 +1,94 @@
use candle::{Module, Result, Tensor};
use candle_nn as nn;
pub struct Qkv {
pub q: Tensor,
pub k: Tensor,
pub v: Tensor,
}
pub struct Mlp {
fc1: nn::Linear,
act: nn::Activation,
fc2: nn::Linear,
}
impl Mlp {
pub fn new(
in_features: usize,
hidden_features: usize,
vb: candle_nn::VarBuilder,
) -> Result<Self> {
let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
let act = nn::Activation::GeluPytorchTanh;
let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
Ok(Self { fc1, act, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = self.fc1.forward(x)?;
let x = self.act.forward(&x)?;
self.fc2.forward(&x)
}
}
pub struct QkvOnlyAttnProjections {
qkv: nn::Linear,
head_dim: usize,
}
impl QkvOnlyAttnProjections {
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
// {'dim': 1536, 'num_heads': 24}
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
Ok(Self { qkv, head_dim })
}
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
let qkv = self.qkv.forward(x)?;
split_qkv(&qkv, self.head_dim)
}
}
pub struct AttnProjections {
head_dim: usize,
qkv: nn::Linear,
proj: nn::Linear,
}
impl AttnProjections {
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
let proj = nn::linear(dim, dim, vb.pp("proj"))?;
Ok(Self {
head_dim,
qkv,
proj,
})
}
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
let qkv = self.qkv.forward(x)?;
split_qkv(&qkv, self.head_dim)
}
pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
self.proj.forward(x)
}
}
fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
let (batch_size, seq_len, _) = qkv.dims3()?;
let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
let q = qkv.get_on_dim(2, 0)?;
let q = q.reshape((batch_size, seq_len, ()))?;
let k = qkv.get_on_dim(2, 1)?;
let k = k.reshape((batch_size, seq_len, ()))?;
let v = qkv.get_on_dim(2, 2)?;
Ok(Qkv { q, k, v })
}


@@ -0,0 +1,89 @@
use super::fastvit;
use super::openclip::text_model;
use candle::{Result, Tensor, D};
use candle_nn::{Func, VarBuilder};
#[derive(Clone, Debug)]
pub struct MobileClipModel {
text_model: text_model::OpenClipTextTransformer,
vision_model: Func<'static>,
text_projection: Tensor,
logit_scale: Tensor,
}
#[derive(Clone, Debug)]
pub struct MobileClipConfig {
pub text_config: text_model::Config,
pub vision_config: fastvit::Config,
pub image_size: usize,
}
impl MobileClipConfig {
pub fn s1() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci1();
Self {
text_config,
vision_config,
image_size: 256,
}
}
pub fn s2() -> Self {
let text_config = text_model::Config::vit_base_patch32();
let vision_config = fastvit::Config::mci2();
Self {
text_config,
vision_config,
image_size: 256,
}
}
}
impl MobileClipModel {
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
let text_projection = vs.get(
(c.text_config.embed_dim, c.text_config.projection_dim),
"text.text_projection",
)?;
let logit_scale = vs.get(&[], "logit_scale")?;
Ok(Self {
text_model,
vision_model,
text_projection,
logit_scale,
})
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
input_ids
.apply(&self.text_model)?
.matmul(&self.text_projection)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
pixel_values.apply(&self.vision_model)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
v.broadcast_div(&l2_norm)
}


@@ -1,3 +1,4 @@
pub mod based;
pub mod beit;
pub mod bert;
pub mod bigcode;
@@ -8,6 +9,7 @@ pub mod clip;
pub mod codegeex4_9b;
pub mod convmixer;
pub mod convnext;
pub mod dac;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod dinov2reg4;
@@ -17,8 +19,12 @@ pub mod efficientvit;
pub mod encodec;
pub mod eva2;
pub mod falcon;
pub mod fastvit;
pub mod flux;
pub mod gemma;
pub mod gemma2;
pub mod glm4;
pub mod granite;
pub mod hiera;
pub mod jina_bert;
pub mod llama;
@@ -28,14 +34,19 @@ pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;
pub mod mimi;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
pub mod mmdit;
pub mod mobileclip;
pub mod mobilenetv4;
pub mod mobileone;
pub mod moondream;
pub mod mpt;
pub mod olmo;
pub mod openclip;
pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;


@@ -167,7 +167,7 @@ impl VisionTransformer {
let blocks = (0..cfg.num_blocks)
.map(|i| {
VitBlock::new(
vb.pp(&format!("blocks.{}", i)),
vb.pp(format!("blocks.{}", i)),
cfg.embed_dim,
cfg.num_heads,
cfg,


@@ -0,0 +1 @@
pub mod text_model;


@@ -0,0 +1,266 @@
//! Text encoder as used in most OpenCLIP pretrained models
//! https://github.com/mlfoundations/open_clip
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
#[derive(Debug, Clone)]
pub struct Config {
pub vocab_size: usize,
pub embed_dim: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub pad_with: Option<String>,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub projection_dim: usize,
}
impl Config {
pub fn vit_base_patch32() -> Self {
Self {
vocab_size: 49408,
embed_dim: 512,
intermediate_size: 2048,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 8,
projection_dim: 512,
}
}
}
#[derive(Clone, Debug)]
struct TextEmbeddings {
token_embedding: Embedding,
position_embedding: Tensor,
}
impl TextEmbeddings {
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding = vs.get(
(c.max_position_embeddings, c.embed_dim),
"positional_embedding",
)?;
Ok(TextEmbeddings {
token_embedding,
position_embedding,
})
}
}
impl Module for TextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = self.token_embedding.forward(input_ids)?;
let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
inputs_embeds.broadcast_add(&position_embedding)
}
}
#[derive(Clone, Debug)]
struct Attention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
q_proj: candle_nn::Linear,
out_proj: Linear,
head_dim: usize,
scale: f64,
num_attention_heads: usize,
}
impl Attention {
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
let embed_dim = c.embed_dim;
let num_attention_heads = c.num_attention_heads;
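// The checkpoint stores the q/k/v projections as a single fused in_proj tensor; split it into three equal chunks along the output dimension.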
let in_proj_weights = vs
.get((embed_dim * 3, embed_dim), "in_proj_weight")?
.chunk(3, 0)?;
let (q_w, k_w, v_w) = (
&in_proj_weights[0],
&in_proj_weights[1],
&in_proj_weights[2],
);
let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
let head_dim = embed_dim / num_attention_heads;
let scale = (head_dim as f64).powf(-0.5);
Ok(Attention {
k_proj,
v_proj,
q_proj,
out_proj,
head_dim,
scale,
num_attention_heads,
})
}
fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
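// Reshape to (bsz, heads, seq_len, head_dim); attention is computed in f32 and cast back to the input dtype afterwards.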
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?
.to_dtype(DType::F32)
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let in_dtype = xs.dtype();
let (bsz, seq_len, embed_dim) = xs.dims3()?;
let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
let q = (q * self.scale)?;
let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
let attn_weights = softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
let attn_output = attn_output
.transpose(1, 2)?
.contiguous()?
.reshape((bsz, seq_len, embed_dim))?;
let out = self.out_proj.forward(&attn_output)?;
Ok(out)
}
}
#[derive(Clone, Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
Ok(Mlp { fc1, fc2 })
}
}
impl Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?;
self.fc2.forward(&xs.gelu_erf()?)
}
}
#[derive(Clone, Debug)]
struct EncoderLayer {
self_attn: Attention,
layer_norm1: LayerNorm,
mlp: Mlp,
layer_norm2: LayerNorm,
}
impl EncoderLayer {
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
let self_attn = Attention::new(vs.pp("attn"), c)?;
let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
let mlp = Mlp::new(vs.pp("mlp"), c)?;
let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
Ok(EncoderLayer {
self_attn,
layer_norm1,
mlp,
layer_norm2,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self.layer_norm1.forward(xs)?;
let xs = self.self_attn.forward(&xs)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self.layer_norm2.forward(&xs)?;
let xs = self.mlp.forward(&xs)?;
let out = (xs + residual)?;
Ok(out)
}
}
#[derive(Clone, Debug)]
pub struct Encoder {
layers: Vec<EncoderLayer>,
}
impl Encoder {
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
let vs = vs.pp("resblocks");
let mut layers: Vec<EncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(Encoder { layers })
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs)?;
}
Ok(xs)
}
}
/// A text transformer as used in CLIP variants.
#[derive(Clone, Debug)]
pub struct OpenClipTextTransformer {
embeddings: TextEmbeddings,
encoder: Encoder,
final_layer_norm: LayerNorm,
}
impl OpenClipTextTransformer {
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
let embeddings = TextEmbeddings::new(vs.clone(), c)?;
let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
let encoder = Encoder::new(vs.pp("transformer"), c)?;
Ok(OpenClipTextTransformer {
embeddings,
encoder,
final_layer_norm,
})
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_ids = self.embeddings.forward(input_ids)?;
let input_ids = self.encoder.forward(&input_ids)?;
self.final_layer_norm.forward(&input_ids)
}
}
impl Module for OpenClipTextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let output = self.forward(input_ids)?;
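// Pool by taking the hidden state at each sequence's end-of-text token; the EOT token has the highest id in the CLIP vocabulary, so an argmax over the input ids finds its position.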
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
let mut indices = Vec::new();
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
indices.push(index);
}
Tensor::cat(&indices, 0)
}
}
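To make the pooling behaviour above concrete, here is a small sketch (not part of the diff). It assumes candle_nn::VarBuilder::zeros for zero-initialized placeholder weights, purely to exercise the call pattern and shapes; real use would load OpenCLIP weights into the VarBuilder instead.

fn text_encoder_sketch() -> Result<Tensor> {
    let dev = candle::Device::Cpu;
    // Zero-initialized placeholder weights, only to show the call pattern and output shape.
    let vs = VarBuilder::zeros(DType::F32, &dev);
    let cfg = Config::vit_base_patch32();
    let model = OpenClipTextTransformer::new(vs, &cfg)?;
    // Token ids for a single caption; the largest id stands in for the end-of-text token.
    let input_ids = Tensor::new(&[[10u32, 4u32, 42u32]], &dev)?;
    // The Module impl pools at the end-of-text position and returns a (1, embed_dim) tensor.
    Module::forward(&model, &input_ids)
}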

View File

@@ -0,0 +1,456 @@
use crate::generation::LogitsProcessor;
use crate::models::t5;
use candle::{IndexOp, Result, Tensor};
use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
#[derive(serde::Deserialize, Debug, Clone)]
pub struct DecoderConfig {
pub vocab_size: usize,
pub max_position_embeddings: usize,
pub num_hidden_layers: usize,
pub ffn_dim: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: Option<usize>,
pub num_cross_attention_key_value_heads: Option<usize>,
pub activation_function: Activation,
pub hidden_size: usize,
pub scale_embedding: bool,
pub num_codebooks: usize,
pub pad_token_id: usize,
pub bos_token_id: usize,
pub eos_token_id: usize,
pub tie_word_embeddings: bool,
pub rope_embeddings: bool,
pub rope_theta: f64,
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub decoder_start_token_id: u32,
pub pad_token_id: u32,
pub decoder: DecoderConfig,
pub text_encoder: t5::Config,
pub vocab_size: usize,
pub audio_encoder: crate::models::dac::Config,
}
#[derive(Debug, Clone)]
pub struct Attention {
k_proj: Linear,
v_proj: Linear,
q_proj: Linear,
out_proj: Linear,
is_causal: bool,
kv_cache: Option<(Tensor, Tensor)>,
scaling: f64,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
}
impl Attention {
fn new(
num_kv_heads: usize,
is_causal: bool,
cfg: &DecoderConfig,
vb: VarBuilder,
) -> Result<Self> {
if cfg.rope_embeddings {
candle::bail!("rope embeddings are not supported");
}
let embed_dim = cfg.hidden_size;
let head_dim = embed_dim / cfg.num_attention_heads;
let kv_out_dim = num_kv_heads * head_dim;
let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
Ok(Self {
k_proj,
v_proj,
q_proj,
out_proj,
is_causal,
kv_cache: None,
scaling: (head_dim as f64).powf(-0.5),
num_heads: cfg.num_attention_heads,
num_kv_heads,
num_kv_groups: cfg.num_attention_heads / num_kv_heads,
head_dim,
})
}
fn forward(
&mut self,
xs: &Tensor,
key_value_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
.reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let key_states = match key_value_states {
Some(states) => states.apply(&self.k_proj)?,
None => xs.apply(&self.k_proj)?,
};
let key_states = key_states
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let value_states = match key_value_states {
Some(states) => states.apply(&self.v_proj)?,
None => xs.apply(&self.v_proj)?,
};
let value_states = value_states
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
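// Only the causal self-attention caches keys/values across decoding steps; cross-attention (is_causal == false) recomputes them from the encoder output on each call.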
if self.is_causal {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&value_states)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, tgt_len, ()))?
.apply(&self.out_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
pub struct DecoderLayer {
self_attn: Attention,
self_attn_layer_norm: LayerNorm,
encoder_attn: Attention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
activation: Activation,
}
impl DecoderLayer {
fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
let self_attn_layer_norm =
layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn_layer_norm =
layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
activation: cfg.activation_function,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
encoder_xs: &Tensor,
encoder_attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
// Self attention
let residual = xs;
let xs = xs.apply(&self.self_attn_layer_norm)?;
let xs = self.self_attn.forward(&xs, None, attention_mask)?;
let xs = (residual + xs)?;
// Cross attention
let residual = &xs;
let xs = xs.apply(&self.encoder_attn_layer_norm)?;
let xs = self
.encoder_attn
.forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
let xs = (residual + xs)?;
// Fully connected
let residual = &xs;
let xs = xs
.apply(&self.final_layer_norm)?
.apply(&self.fc1)?
.apply(&self.activation)?
.apply(&self.fc2)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
self.encoder_attn.clear_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct Decoder {
embed_tokens: Vec<candle_nn::Embedding>,
embed_positions: Tensor,
layers: Vec<DecoderLayer>,
layer_norm: LayerNorm,
num_codebooks: usize,
hidden_size: usize,
lm_heads: Vec<Linear>,
dtype: candle::DType,
}
impl Decoder {
pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
let vb_d = vb.pp("model.decoder");
let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
let vb_e = vb_d.pp("embed_tokens");
for embed_idx in 0..cfg.num_codebooks {
let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
embed_tokens.push(e)
}
let embed_positions = vb_d.get(
(cfg.max_position_embeddings, cfg.hidden_size),
"embed_positions.weights",
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_d.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
let vb_l = vb.pp("lm_heads");
for lm_idx in 0..cfg.num_codebooks {
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
lm_heads.push(lm_head)
}
Ok(Self {
embed_tokens,
embed_positions,
layers,
layer_norm,
num_codebooks: cfg.num_codebooks,
lm_heads,
hidden_size: cfg.hidden_size,
dtype: vb.dtype(),
})
}
pub fn forward(
&mut self,
input_ids: &Tensor,
prompt_hidden_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_xs: &Tensor,
encoder_attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Vec<Tensor>> {
let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
if num_codebooks != self.num_codebooks {
candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
}
let mut inputs_embeds = Tensor::zeros(
(b_sz, seq_len, self.hidden_size),
self.dtype,
input_ids.device(),
)?;
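// The decoder input embedding is the sum of the per-codebook token embeddings at each position.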
for (idx, embs) in self.embed_tokens.iter().enumerate() {
let e = input_ids.i((.., idx))?.apply(embs)?;
inputs_embeds = (inputs_embeds + e)?
}
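// When prompt hidden states are provided (the first step during generation), they are prepended along the sequence dimension.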
let inputs_embeds = match prompt_hidden_states {
None => inputs_embeds,
Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
};
let embed_positions = self
.embed_positions
.i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
}
let xs = xs.apply(&self.layer_norm)?;
let mut lm_logits = Vec::with_capacity(self.num_codebooks);
for lm_head in self.lm_heads.iter() {
let logits = xs.apply(lm_head)?;
lm_logits.push(logits)
}
Ok(lm_logits)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
#[derive(Debug, Clone)]
pub struct Model {
pub embed_prompts: candle_nn::Embedding,
pub enc_to_dec_proj: Option<Linear>,
pub decoder: Decoder,
pub text_encoder: t5::T5EncoderModel,
pub decoder_start_token_id: u32,
pub pad_token_id: u32,
pub audio_encoder: crate::models::dac::Model,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
let embed_prompts = candle_nn::embedding(
cfg.vocab_size,
cfg.decoder.hidden_size,
vb.pp("embed_prompts"),
)?;
let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
let proj = linear(
cfg.text_encoder.d_model,
cfg.decoder.hidden_size,
true,
vb.pp("enc_to_dec_proj"),
)?;
Some(proj)
} else {
None
};
let audio_encoder =
crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?;
Ok(Self {
decoder,
text_encoder,
embed_prompts,
enc_to_dec_proj,
decoder_start_token_id: cfg.decoder_start_token_id,
pad_token_id: cfg.pad_token_id,
audio_encoder,
})
}
/// Note that the returned tensor uses the CPU device.
pub fn generate(
&mut self,
prompt_tokens: &Tensor,
description_tokens: &Tensor,
mut lp: LogitsProcessor,
max_steps: usize,
) -> Result<Tensor> {
self.decoder.clear_kv_cache();
self.text_encoder.clear_kv_cache();
let encoded = self.text_encoder.forward(description_tokens)?;
let encoded = match self.enc_to_dec_proj.as_ref() {
None => encoded,
Some(proj) => encoded.apply(proj)?,
};
let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
let num_codebooks = self.decoder.num_codebooks;
let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
let mut all_audio_tokens = vec![vec![]; num_codebooks];
let prompt_len = prompt_hidden_states.dim(1)?;
for step in 0..max_steps {
let input_ids = Tensor::from_slice(
audio_tokens.as_slice(),
(1, num_codebooks, 1),
prompt_tokens.device(),
)?;
let (prompt_hidden_states, pos) = if step == 0 {
(Some(&prompt_hidden_states), 0)
} else {
(None, step + prompt_len)
};
let causal_mask = if pos == 0 {
self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
} else {
self.prepare_causal_mask(1, pos + 1, input_ids.device())?
};
let logits = self.decoder.forward(
&input_ids,
prompt_hidden_states,
Some(&causal_mask),
&encoded,
None,
pos,
)?;
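// Delay-pattern style decoding: codebook k only starts being sampled at step k, so higher codebooks keep the start token until their turn, and a codebook that emits the pad token stays padded.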
for (logit_idx, logit) in logits.iter().enumerate() {
if logit_idx > step {
break;
}
if audio_tokens[logit_idx] != self.pad_token_id {
let logit = logit.i((0, logit.dim(1)? - 1))?;
let token = lp.sample(&logit)?;
audio_tokens[logit_idx] = token
}
}
if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
break;
}
for (cb_idx, &token) in audio_tokens.iter().enumerate() {
if token != self.decoder_start_token_id && token != self.pad_token_id {
all_audio_tokens[cb_idx].push(token)
}
}
}
let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
all_audio_tokens.iter_mut().for_each(|v| {
v.resize(min_len, 0);
});
let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
Ok(all_audio_tokens)
}
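// Builds an additive attention mask of shape (q_len, kv_len): entry (i, j) is -inf when key j lies
// ahead of query i in the full sequence, i.e. when j > i + (kv_len - q_len). For instance q_len = 2,
// kv_len = 4 gives [[0, 0, 0, -inf], [0, 0, 0, 0]].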
fn prepare_causal_mask(
&self,
q_len: usize,
kv_len: usize,
device: &candle::Device,
) -> Result<Tensor> {
let mask: Vec<_> = (0..q_len)
.flat_map(|i| {
(0..kv_len).map(move |j| {
if i + kv_len < j + q_len {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
Tensor::from_slice(&mask, (q_len, kv_len), device)
}
}

View File

@@ -361,7 +361,7 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let base_model = Model::new(cfg, vb.clone())?;
let lm_head = if vb.contains_tensor("lm_head") {
let lm_head = if vb.contains_tensor("lm_head.weight") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)

View File

@@ -404,7 +404,7 @@ impl SegformerEncoder {
stride,
num_channels,
hidden_size,
vb.pp(&format!("patch_embeddings.{}", i)),
vb.pp(format!("patch_embeddings.{}", i)),
)?);
let mut layers = Vec::with_capacity(config.depths[i]);
for j in 0..config.depths[i] {
@@ -417,14 +417,14 @@
num_attention_heads,
sequence_reduction_ratio,
mlp_ratio,
vb.pp(&format!("block.{}.{}", i, j)),
vb.pp(format!("block.{}.{}", i, j)),
)?);
}
blocks.push(layers);
layer_norms.push(layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp(&format!("layer_norm.{}", i)),
vb.pp(format!("layer_norm.{}", i)),
)?);
}
Ok(Self {
@@ -507,7 +507,7 @@ impl SegformerDecodeHead {
linear_c.push(SegformerMLP::new(
config,
hidden_size,
vb.pp(&format!("linear_c.{}", i)),
vb.pp(format!("linear_c.{}", i)),
)?);
}
let linear_fuse = conv2d_no_bias(

View File

@@ -659,7 +659,7 @@ struct T5Stack {
impl T5Stack {
fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
.map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let final_layer_norm = T5LayerNorm::load(
cfg.d_model,

View File

@@ -260,7 +260,7 @@ impl AudioEncoder {
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
@@ -321,7 +321,7 @@ impl TextDecoder {
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;

View File

@@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
bboxes_for_class.truncate(current_index);
}
}
// Updates confidences, starting from the highest-confidence box and comparing it against each subsequent box.
fn update_confidences<D>(
bboxes_for_class: &[Bbox<D>],
updated_confidences: &mut [f32],
iou_threshold: f32,
sigma: f32,
) {
let len = bboxes_for_class.len();
for current_index in 0..len {
let current_bbox = &bboxes_for_class[current_index];
for index in (current_index + 1)..len {
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
if iou_val > iou_threshold {
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
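// e.g. iou = 0.8 with sigma = 0.5 gives decay = exp(-0.64 / 0.5) ≈ 0.28, so a heavily overlapping box keeps about 28% of its confidence.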
let decay = (-iou_val * iou_val / sigma).exp();
let updated_confidence = bboxes_for_class[index].confidence * decay;
updated_confidences[index] = updated_confidence;
}
}
}
}
// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
pub fn soft_non_maximum_suppression<D>(
bboxes: &mut [Vec<Bbox<D>>],
iou_threshold: Option<f32>,
confidence_threshold: Option<f32>,
sigma: Option<f32>,
) {
let iou_threshold = iou_threshold.unwrap_or(0.5);
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
let sigma = sigma.unwrap_or(0.5);
for bboxes_for_class in bboxes.iter_mut() {
// Sort boxes by confidence in descending order
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut updated_confidences = bboxes_for_class
.iter()
.map(|bbox| bbox.confidence)
.collect::<Vec<_>>();
update_confidences(
bboxes_for_class,
&mut updated_confidences,
iou_threshold,
sigma,
);
// Apply the updated confidences, zeroing any that fall below the confidence threshold.
for (i, &confidence) in updated_confidences.iter().enumerate() {
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
0.0
} else {
confidence
};
}
}
}
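For reference, a small usage sketch (not part of the diff), assuming the Bbox<D> struct defined earlier in this module with xmin, ymin, xmax, ymax, confidence and data fields:

fn soft_nms_sketch() {
    // Two heavily overlapping detections of the same class.
    let mut bboxes: Vec<Vec<Bbox<()>>> = vec![vec![
        Bbox { xmin: 0.0, ymin: 0.0, xmax: 10.0, ymax: 10.0, confidence: 0.9, data: () },
        Bbox { xmin: 1.0, ymin: 1.0, xmax: 11.0, ymax: 11.0, confidence: 0.8, data: () },
    ]];
    // Defaults: iou_threshold 0.5, confidence_threshold 0.1, sigma 0.5.
    soft_non_maximum_suppression(&mut bboxes, None, None, None);
    // The top box keeps its 0.9 confidence; the overlapping box is decayed rather than
    // removed outright, and zeroed only if it falls below the confidence threshold.
}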

Some files were not shown because too many files have changed in this diff.