mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Compare commits
48 Commits
metal-gemm
...
0.7.0
Author | SHA1 | Date | |
---|---|---|---|
c2fca0ca11 | |||
844d45cde4 | |||
af2104078f | |||
5fc4f17727 | |||
c58c5d5b01 | |||
382c6b51af | |||
6eea45a761 | |||
ebf722b446 | |||
c09afc211c | |||
b60faebea4 | |||
72d649058b | |||
0cb0bd1dfa | |||
afb6575835 | |||
5635650d38 | |||
13b2a8a4a0 | |||
e3261216b1 | |||
c02b7c3272 | |||
86613c00e2 | |||
29e25c458d | |||
aafa24ed93 | |||
fdc2622686 | |||
ccdbe87639 | |||
2ec8729d51 | |||
e3c146ada6 | |||
1e96b8b695 | |||
a8288b7a72 | |||
6070278a31 | |||
b47c0bc475 | |||
14fd2d97e0 | |||
31a1075f4b | |||
236b29ff15 | |||
58197e1896 | |||
736d8eb752 | |||
7cff5898ec | |||
b75ef051cf | |||
c1b9e07e35 | |||
69fdcfe96a | |||
2b75dd9551 | |||
53ce65f706 | |||
68aa9c7320 | |||
35e5f31397 | |||
d3fe989d08 | |||
14db029494 | |||
6e6c1c99b0 | |||
b7d9af00cc | |||
59bbc0d287 | |||
dfdce2b602 | |||
500c9f2882 |
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@ -18,9 +18,9 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest] # For now, only test on Linux
|
||||
steps:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
@ -65,4 +65,4 @@ jobs:
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
python -m pytest -s -v tests
|
||||
python -m pytest -s -v tests
|
||||
|
12
.github/workflows/rust-ci.yml
vendored
12
.github/workflows/rust-ci.yml
vendored
@ -1,6 +1,6 @@
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
@ -15,7 +15,7 @@ jobs:
|
||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||
rust: [stable]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -34,7 +34,7 @@ jobs:
|
||||
os: [ubuntu-latest, windows-latest, macOS-latest]
|
||||
rust: [stable]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -49,7 +49,7 @@ jobs:
|
||||
name: Rustfmt
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
@ -65,7 +65,7 @@ jobs:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
|
6
.gitignore
vendored
6
.gitignore
vendored
@ -40,3 +40,9 @@ candle-wasm-examples/*/package-lock.json
|
||||
candle-wasm-examples/**/config*.json
|
||||
.DS_Store
|
||||
.idea/*
|
||||
__pycache__
|
||||
out.safetensors
|
||||
out.wav
|
||||
bria.mp3
|
||||
bria.safetensors
|
||||
bria.wav
|
||||
|
20
Cargo.toml
20
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.6.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.7.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.7.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.7.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.7.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.7.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.7.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.12.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
|
11
README.md
11
README.md
@ -63,7 +63,9 @@ We also provide a some command line based examples using state of the art models
|
||||
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
|
||||
the SOLAR-10.7B variant.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
|
||||
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
|
||||
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
|
||||
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
|
||||
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
||||
Griffin based models from Google that mix attention with a RNN like state.
|
||||
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
|
||||
@ -118,6 +120,8 @@ We also provide a some command line based examples using state of the art models
|
||||
model using residual vector quantization.
|
||||
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
|
||||
text-to-speech.
|
||||
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
|
||||
model.
|
||||
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
|
||||
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
|
||||
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
|
||||
@ -206,7 +210,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- StarCoder, StarCoder2.
|
||||
- Phi 1, 1.5, 2, and 3.
|
||||
- Mamba, Minimal Mamba
|
||||
- Gemma 2b and 7b.
|
||||
- Gemma v1 2b and 7b+, v2 2b and 9b.
|
||||
- Mistral 7b v0.1.
|
||||
- Mixtral 8x7b v0.1.
|
||||
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
|
||||
@ -234,9 +238,10 @@ If you have an addition to this list, please submit a pull request.
|
||||
- Whisper, multi-lingual speech-to-text.
|
||||
- EnCodec, audio compression model.
|
||||
- MetaVoice-1B, text-to-speech model.
|
||||
- Parler-TTS, text-to-speech model.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
|
||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- SegFormer.
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::WithDType;
|
||||
use cudarc;
|
||||
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
|
||||
use cudarc::cudnn::safe::{ConvForward, Cudnn};
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
@ -87,7 +87,7 @@ pub(crate) fn launch_conv2d<
|
||||
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||
[params.b_size as i32, params.c_out as i32, h_out, w_out],
|
||||
)?;
|
||||
let conv2d = Conv2dForward {
|
||||
let conv2d = ConvForward {
|
||||
conv: &conv,
|
||||
x: &x,
|
||||
w: &w,
|
||||
|
@ -174,6 +174,7 @@ impl Map1 for Im2Col1D {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
struct Im2Col {
|
||||
h_k: usize,
|
||||
w_k: usize,
|
||||
@ -183,6 +184,7 @@ struct Im2Col {
|
||||
}
|
||||
|
||||
impl Im2Col {
|
||||
#[allow(unused)]
|
||||
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
|
||||
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
|
||||
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
|
||||
|
@ -141,28 +141,117 @@ impl<T> IndexOp<T> for Tensor
|
||||
where
|
||||
T: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i(0)?;
|
||||
/// assert_eq!(b.shape().dims(), &[2]);
|
||||
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
|
||||
///
|
||||
/// let c = a.i(..2)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(c.to_vec2::<f64>()?, &[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.]
|
||||
/// ]);
|
||||
///
|
||||
/// let d = a.i(1..)?;
|
||||
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(d.to_vec2::<f64>()?, &[
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, index: T) -> Result<Tensor, Error> {
|
||||
self.index(&[index.into()])
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> IndexOp<(A,)> for Tensor
|
||||
where
|
||||
A: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0f32, 1.],
|
||||
/// [2. , 3.],
|
||||
/// [4. , 5.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i((0,))?;
|
||||
/// assert_eq!(b.shape().dims(), &[2]);
|
||||
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
|
||||
///
|
||||
/// let c = a.i((..2,))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||
/// [0., 1.],
|
||||
/// [2., 3.]
|
||||
/// ]);
|
||||
///
|
||||
/// let d = a.i((1..,))?;
|
||||
/// assert_eq!(d.shape().dims(), &[2, 2]);
|
||||
/// assert_eq!(d.to_vec2::<f32>()?, &[
|
||||
/// [2., 3.],
|
||||
/// [4., 5.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
|
||||
self.index(&[a.into()])
|
||||
}
|
||||
}
|
||||
#[allow(non_snake_case)]
|
||||
impl<A, B> IndexOp<(A, B)> for Tensor
|
||||
where
|
||||
A: Into<TensorIndexer>,
|
||||
B: Into<TensorIndexer>,
|
||||
{
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, DType, Device, IndexOp};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.i((1, 0))?;
|
||||
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
|
||||
///
|
||||
/// let c = a.i((..2, 1))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2]);
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
///
|
||||
/// let d = a.i((2.., ..))?;
|
||||
/// assert_eq!(c.shape().dims(), &[2]);
|
||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
||||
self.index(&[a.into(), b.into()])
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! index_op_tuple {
|
||||
($($t:ident),+) => {
|
||||
($doc:tt, $($t:ident),+) => {
|
||||
#[allow(non_snake_case)]
|
||||
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
|
||||
where
|
||||
$($t: Into<TensorIndexer>,)*
|
||||
{
|
||||
#[doc=$doc]
|
||||
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
|
||||
self.index(&[$($t.into(),)*])
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
index_op_tuple!(A);
|
||||
index_op_tuple!(A, B);
|
||||
index_op_tuple!(A, B, C);
|
||||
index_op_tuple!(A, B, C, D);
|
||||
index_op_tuple!(A, B, C, D, E);
|
||||
index_op_tuple!(A, B, C, D, E, F);
|
||||
index_op_tuple!(A, B, C, D, E, F, G);
|
||||
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
|
||||
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
|
||||
|
@ -65,6 +65,7 @@ pub mod scalar;
|
||||
pub mod shape;
|
||||
mod sort;
|
||||
mod storage;
|
||||
pub mod streaming;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
mod tensor_cat;
|
||||
@ -80,10 +81,11 @@ pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, Inp
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||
pub use error::{Error, Result};
|
||||
pub use indexer::IndexOp;
|
||||
pub use indexer::{IndexOp, TensorIndexer};
|
||||
pub use layout::Layout;
|
||||
pub use shape::{Shape, D};
|
||||
pub use storage::Storage;
|
||||
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
|
||||
pub use strided_index::{StridedBlocks, StridedIndex};
|
||||
pub use tensor::{Tensor, TensorId};
|
||||
pub use variable::Var;
|
||||
|
@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
|
||||
use super::MetalError;
|
||||
|
||||
@ -22,7 +22,73 @@ impl DeviceId {
|
||||
}
|
||||
|
||||
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
||||
pub(crate) struct Commands {
|
||||
/// Single command queue for the entire device.
|
||||
command_queue: CommandQueue,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||
/// to start to work).
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
command_buffer: CommandBuffer,
|
||||
/// Keeps track of the current amount of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
command_buffer_index: usize,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
compute_per_buffer: usize,
|
||||
}
|
||||
|
||||
impl Commands {
|
||||
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 50,
|
||||
};
|
||||
Ok(Self {
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index: 0,
|
||||
compute_per_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
|
||||
let mut command_buffer = self.command_buffer.to_owned();
|
||||
let mut flushed = false;
|
||||
if self.command_buffer_index > self.compute_per_buffer {
|
||||
self.command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
self.command_buffer = command_buffer.clone();
|
||||
self.command_buffer_index = 0;
|
||||
flushed = true;
|
||||
}
|
||||
self.command_buffer_index += 1;
|
||||
Ok((flushed, command_buffer))
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&mut self) -> Result<()> {
|
||||
match self.command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Already committed");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
self.command_buffer.commit();
|
||||
self.command_buffer.wait_until_completed();
|
||||
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
@ -33,27 +99,8 @@ pub struct MetalDevice {
|
||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||
pub(crate) device: metal::Device,
|
||||
|
||||
/// Single command queue for the entire device.
|
||||
pub(crate) command_queue: CommandQueue,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
|
||||
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
|
||||
/// to start to work).
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
|
||||
/// Keeps track of the current amount of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
|
||||
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
|
||||
pub(crate) compute_per_buffer: usize,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
pub(crate) kernels: Arc<Kernels>,
|
||||
pub(crate) commands: Arc<RwLock<Commands>>,
|
||||
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||
@ -67,9 +114,15 @@ pub struct MetalDevice {
|
||||
///
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||
/// (strong_count = 1).
|
||||
pub(crate) buffers: AllocatedBuffers,
|
||||
pub(crate) buffers: Arc<RwLock<BufferMap>>,
|
||||
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
pub(crate) kernels: Arc<Kernels>,
|
||||
/// Seed for random number generation.
|
||||
pub(crate) seed: Arc<Mutex<Buffer>>,
|
||||
/// Whether to use the MLX matmul kernels instead of the MFA ones.
|
||||
pub(crate) use_mlx_mm: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -87,6 +140,10 @@ impl std::ops::Deref for MetalDevice {
|
||||
}
|
||||
|
||||
impl MetalDevice {
|
||||
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
|
||||
self.use_mlx_mm = use_mlx_mm
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
self.id
|
||||
}
|
||||
@ -95,44 +152,31 @@ impl MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
.filter(|s| Arc::strong_count(*s) > 1)
|
||||
.map(Arc::clone)
|
||||
.collect();
|
||||
*subbuffers = newbuffers;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = command_buffer_lock.to_owned();
|
||||
let mut index = self
|
||||
.command_buffer_index
|
||||
.write()
|
||||
.map_err(MetalError::from)?;
|
||||
if *index > self.compute_per_buffer {
|
||||
command_buffer.commit();
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*command_buffer_lock = command_buffer.clone();
|
||||
*index = 0;
|
||||
|
||||
self.drop_unused_buffers()?;
|
||||
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||
let (flushed, command_buffer) = commands.command_buffer()?;
|
||||
if flushed {
|
||||
self.drop_unused_buffers()?
|
||||
}
|
||||
*index += 1;
|
||||
Ok(command_buffer)
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) -> Result<()> {
|
||||
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Already committed");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
Ok(())
|
||||
let mut commands = self.commands.write().map_err(MetalError::from)?;
|
||||
commands.wait_until_completed()
|
||||
}
|
||||
|
||||
pub fn kernels(&self) -> &Kernels {
|
||||
@ -180,6 +224,7 @@ impl MetalDevice {
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
|
||||
let subbuffers = buffers
|
||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||
.or_insert(vec![]);
|
||||
@ -210,40 +255,6 @@ impl MetalDevice {
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
fn find_available_buffer(
|
||||
&self,
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
buffers: &RwLockWriteGuard<BufferMap>,
|
||||
) -> Option<Arc<Buffer>> {
|
||||
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||
for sub in subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
best_buffer = Some(sub);
|
||||
best_buffer_size = *buffer_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best_buffer.cloned()
|
||||
}
|
||||
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
.filter(|s| Arc::strong_count(*s) > 1)
|
||||
.map(Arc::clone)
|
||||
.collect();
|
||||
*subbuffers = newbuffers;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
@ -252,7 +263,7 @@ impl MetalDevice {
|
||||
_name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||
if let Some(b) = find_available_buffer(size, option, &buffers) {
|
||||
// Cloning also ensures we increment the strong count
|
||||
return Ok(b.clone());
|
||||
}
|
||||
@ -291,3 +302,23 @@ impl MetalDevice {
|
||||
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||
size.saturating_sub(1).next_power_of_two() as NSUInteger
|
||||
}
|
||||
|
||||
fn find_available_buffer(
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
buffers: &BufferMap,
|
||||
) -> Option<Arc<Buffer>> {
|
||||
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||
for sub in subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
best_buffer = Some(sub);
|
||||
best_buffer_size = *buffer_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
best_buffer.cloned()
|
||||
}
|
||||
|
@ -412,17 +412,42 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
|
||||
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(DType::F16, DType::I64) => "cast_f16_i64_strided",
|
||||
(DType::F16, DType::U32) => "cast_f16_u32_strided",
|
||||
(DType::F16, DType::U8) => "cast_f16_u8_strided",
|
||||
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F32, DType::I64) => "cast_f32_i64_strided",
|
||||
(DType::F32, DType::U32) => "cast_f32_u32_strided",
|
||||
(DType::F32, DType::U8) => "cast_f32_u8_strided",
|
||||
|
||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
|
||||
(DType::I64, DType::F16) => "cast_i64_f16_strided",
|
||||
(DType::I64, DType::U32) => "cast_i64_u32_strided",
|
||||
(DType::I64, DType::U8) => "cast_i64_u8_strided",
|
||||
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16_strided",
|
||||
(DType::U32, DType::F32) => "cast_u32_f32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8_strided",
|
||||
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
|
||||
(DType::U8, DType::F16) => "cast_u8_f16_strided",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32_strided",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64_strided",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16_strided",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32_strided",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32_strided",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32_strided",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
|
||||
}
|
||||
@ -1398,6 +1423,7 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
}
|
||||
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
@ -1406,32 +1432,78 @@ impl BackendStorage for MetalStorage {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "sgemm",
|
||||
DType::F16 => "hgemm",
|
||||
DType::BF16 => "bgemm",
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
|
||||
}
|
||||
};
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("matmul");
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
if self.dtype == DType::BF16 {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
candle_metal_kernels::GemmDType::BF16,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else if self.device.use_mlx_mm {
|
||||
let dtype = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::GemmDType::F32,
|
||||
DType::F16 => candle_metal_kernels::GemmDType::F16,
|
||||
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
|
||||
dtype => {
|
||||
return Err(MetalError::Message(format!(
|
||||
"mlx matmul doesn't support {dtype:?}"
|
||||
))
|
||||
.into())
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "sgemm",
|
||||
DType::F16 => "hgemm",
|
||||
dtype => {
|
||||
return Err(
|
||||
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
lhs_l.stride(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&self.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
@ -1792,31 +1864,25 @@ impl BackendDevice for MetalDevice {
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_buffer));
|
||||
let command_buffer_index = Arc::new(RwLock::new(0));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 50,
|
||||
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
|
||||
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
|
||||
Ok(_) => true,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
4,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
)));
|
||||
let commands = device::Commands::new(command_queue)?;
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
command_buffer_index,
|
||||
compute_per_buffer,
|
||||
buffers,
|
||||
commands: Arc::new(RwLock::new(commands)),
|
||||
buffers: Arc::new(RwLock::new(HashMap::new())),
|
||||
kernels,
|
||||
seed,
|
||||
use_mlx_mm,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -304,6 +304,7 @@ impl Dim for usize {
|
||||
pub enum D {
|
||||
Minus1,
|
||||
Minus2,
|
||||
Minus(usize),
|
||||
}
|
||||
|
||||
impl D {
|
||||
@ -311,6 +312,7 @@ impl D {
|
||||
let dim = match self {
|
||||
Self::Minus1 => -1,
|
||||
Self::Minus2 => -2,
|
||||
Self::Minus(u) => -(*u as i32),
|
||||
};
|
||||
Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
@ -327,6 +329,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus2 if rank >= 2 => Ok(rank - 2),
|
||||
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
@ -336,6 +339,7 @@ impl Dim for D {
|
||||
match self {
|
||||
Self::Minus1 => Ok(rank),
|
||||
Self::Minus2 if rank >= 1 => Ok(rank - 1),
|
||||
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
|
||||
_ => Err(self.out_of_range(shape, op)),
|
||||
}
|
||||
}
|
||||
|
206
candle-core/src/streaming.rs
Normal file
206
candle-core/src/streaming.rs
Normal file
@ -0,0 +1,206 @@
|
||||
use crate::{Result, Shape, Tensor};
|
||||
|
||||
pub trait Dim: crate::shape::Dim + Copy {}
|
||||
impl<T: crate::shape::Dim + Copy> Dim for T {}
|
||||
|
||||
/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
|
||||
/// empty.
|
||||
#[derive(Clone)]
|
||||
pub struct StreamTensor(Option<Tensor>);
|
||||
|
||||
impl std::fmt::Debug for StreamTensor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.0 {
|
||||
Some(t) => write!(f, "{:?}", t.shape()),
|
||||
None => write!(f, "Empty"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<Option<Tensor>> for StreamTensor {
|
||||
fn from(value: Option<Tensor>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<Tensor> for StreamTensor {
|
||||
fn from(value: Tensor) -> Self {
|
||||
Self(Some(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<()> for StreamTensor {
|
||||
fn from(_value: ()) -> Self {
|
||||
Self(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamTensor {
|
||||
pub fn empty() -> Self {
|
||||
Self(None)
|
||||
}
|
||||
|
||||
pub fn from_tensor(tensor: Tensor) -> Self {
|
||||
Self(Some(tensor))
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Option<&Shape> {
|
||||
self.0.as_ref().map(|t| t.shape())
|
||||
}
|
||||
|
||||
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
|
||||
let xs = match (&self.0, &rhs.0) {
|
||||
(Some(lhs), Some(rhs)) => {
|
||||
let xs = Tensor::cat(&[lhs, rhs], dim)?;
|
||||
Some(xs)
|
||||
}
|
||||
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
|
||||
(None, None) => None,
|
||||
};
|
||||
Ok(Self(xs))
|
||||
}
|
||||
|
||||
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||
match &self.0 {
|
||||
None => Ok(0),
|
||||
Some(v) => v.dim(dim),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.0 = None
|
||||
}
|
||||
|
||||
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
|
||||
let t = match &self.0 {
|
||||
None => None,
|
||||
Some(t) => {
|
||||
let seq_len = t.dim(dim)?;
|
||||
if seq_len <= offset {
|
||||
None
|
||||
} else {
|
||||
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
|
||||
Some(t)
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self(t))
|
||||
}
|
||||
|
||||
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
|
||||
/// returned in the first output and the remaining in the second output.
|
||||
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
|
||||
match &self.0 {
|
||||
None => Ok((Self::empty(), Self::empty())),
|
||||
Some(t) => {
|
||||
let seq_len = t.dim(dim)?;
|
||||
let lhs_len = usize::min(seq_len, lhs_len);
|
||||
if lhs_len == 0 {
|
||||
Ok((Self::empty(), t.clone().into()))
|
||||
} else {
|
||||
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
|
||||
let rhs_len = seq_len - lhs_len;
|
||||
let rhs = if rhs_len == 0 {
|
||||
Self::empty()
|
||||
} else {
|
||||
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
|
||||
};
|
||||
Ok((lhs, rhs))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_option(&self) -> Option<&Tensor> {
|
||||
self.0.as_ref()
|
||||
}
|
||||
|
||||
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
|
||||
match &self.0 {
|
||||
None => Ok(Self::empty()),
|
||||
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
|
||||
/// some internal buffering so that enough data has been received for the module to be able to
|
||||
/// perform some operations.
|
||||
pub trait StreamingModule {
|
||||
// TODO: Should we also have a flush method?
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
|
||||
fn reset_state(&mut self);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum BinOp {
|
||||
Add,
|
||||
Mul,
|
||||
Sub,
|
||||
Div,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingBinOp {
|
||||
prev_lhs: StreamTensor,
|
||||
prev_rhs: StreamTensor,
|
||||
pub op: BinOp,
|
||||
pub dim: crate::D,
|
||||
}
|
||||
|
||||
impl StreamingBinOp {
|
||||
pub fn new(op: BinOp, dim: crate::D) -> Self {
|
||||
Self {
|
||||
prev_lhs: StreamTensor::empty(),
|
||||
prev_rhs: StreamTensor::empty(),
|
||||
op,
|
||||
dim,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_state(&mut self) {
|
||||
self.prev_lhs.reset();
|
||||
self.prev_rhs.reset();
|
||||
}
|
||||
|
||||
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
|
||||
match self.op {
|
||||
BinOp::Add => Tensor::add(lhs, rhs),
|
||||
BinOp::Mul => Tensor::mul(lhs, rhs),
|
||||
BinOp::Sub => Tensor::sub(lhs, rhs),
|
||||
BinOp::Div => Tensor::div(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
|
||||
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
|
||||
let lhs_len = lhs.seq_len(self.dim)?;
|
||||
let rhs_len = rhs.seq_len(self.dim)?;
|
||||
let common_len = usize::min(lhs_len, rhs_len);
|
||||
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
|
||||
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
|
||||
let ys = match (lhs.0, rhs.0) {
|
||||
(Some(lhs), Some(rhs)) => {
|
||||
let ys = self.forward(&lhs, &rhs)?;
|
||||
StreamTensor::from_tensor(ys)
|
||||
}
|
||||
(None, None) => StreamTensor::empty(),
|
||||
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
|
||||
};
|
||||
self.prev_lhs = prev_lhs;
|
||||
self.prev_rhs = prev_rhs;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple wrapper that doesn't do any buffering.
|
||||
pub struct Map<T: crate::Module>(T);
|
||||
|
||||
impl<T: crate::Module> StreamingModule for Map<T> {
|
||||
fn reset_state(&mut self) {}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
xs.apply(&self.0)
|
||||
}
|
||||
}
|
@ -370,6 +370,15 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor with all the elements having the same specified value. Note that
|
||||
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [3.5, 3.5, 3.5, 3.5],
|
||||
/// [3.5, 3.5, 3.5, 3.5],
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
pub fn full<D: crate::WithDType, S: Into<Shape>>(
|
||||
value: D,
|
||||
shape: S,
|
||||
@ -379,6 +388,13 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor from an iterator.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_iter<D: crate::WithDType>(
|
||||
iter: impl IntoIterator<Item = D>,
|
||||
device: &Device,
|
||||
@ -390,12 +406,26 @@ impl Tensor {
|
||||
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `1` from `start`.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
||||
Self::arange_step(start, end, D::one(), device)
|
||||
}
|
||||
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `step` from `start`.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn arange_step<D: crate::WithDType>(
|
||||
start: D,
|
||||
end: D,
|
||||
@ -441,6 +471,16 @@ impl Tensor {
|
||||
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
/// If the device is cpu, no data copy is made.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [1., 2., 3.],
|
||||
/// [4., 5., 6.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
@ -451,6 +491,17 @@ impl Tensor {
|
||||
|
||||
/// Creates a new tensor initialized with values from the input slice. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
///```rust
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
|
||||
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
|
||||
///
|
||||
/// assert_eq!(a.to_vec2::<f64>()?, &[
|
||||
/// [2., 3., 4.],
|
||||
/// [5., 6., 7.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
@ -732,6 +783,30 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + len`.
|
||||
/// ```
|
||||
/// use candle_core::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[
|
||||
/// [0f32, 1., 2.],
|
||||
/// [3. , 4., 5.],
|
||||
/// [6. , 7., 8.]
|
||||
/// ], &Device::Cpu)?;
|
||||
///
|
||||
/// let b = a.narrow(0, 1, 2)?;
|
||||
/// assert_eq!(b.shape().dims(), &[2, 3]);
|
||||
/// assert_eq!(b.to_vec2::<f32>()?, &[
|
||||
/// [3., 4., 5.],
|
||||
/// [6., 7., 8.]
|
||||
/// ]);
|
||||
///
|
||||
/// let c = a.narrow(1, 1, 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[3, 1]);
|
||||
/// assert_eq!(c.to_vec2::<f32>()?, &[
|
||||
/// [1.],
|
||||
/// [4.],
|
||||
/// [7.]
|
||||
/// ]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
let dim = dim.to_index(self.shape(), "narrow")?;
|
||||
@ -1950,7 +2025,11 @@ impl Tensor {
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
_ => {
|
||||
bail!("not implemented yet")
|
||||
bail!(
|
||||
"not implemented yet, self.device: {:?}, device: {:?}",
|
||||
self.device(),
|
||||
device
|
||||
)
|
||||
}
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||
|
@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> {
|
||||
tensor.sign()?.to_vec1::<f32>()?,
|
||||
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
|
||||
);
|
||||
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
|
||||
let y = tensor.elu(2.)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
);
|
||||
// This test failed on metal prior to the following PR:
|
||||
// https://github.com/huggingface/candle/pull/2490
|
||||
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, -1.7293, 0.0000, 3.0000]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -67,6 +67,7 @@ onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
||||
[[example]]
|
||||
@ -101,6 +102,10 @@ required-features = ["candle-datasets"]
|
||||
name = "llama2-c"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "mimi"
|
||||
required-features = ["mimi"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["encodec"]
|
||||
@ -108,3 +113,7 @@ required-features = ["encodec"]
|
||||
[[example]]
|
||||
name = "depth_anything_v2"
|
||||
required-features = ["depth_anything_v2"]
|
||||
|
||||
[[example]]
|
||||
name = "silero-vad"
|
||||
required-features = ["onnx"]
|
||||
|
20
candle-examples/examples/based/README.md
Normal file
20
candle-examples/examples/based/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-based
|
||||
|
||||
Experimental, not instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers.
|
||||
|
||||
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
|
||||
|
||||
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
|
||||
|
||||
## Running an example
|
||||
|
||||
```bash
|
||||
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100
|
||||
|
||||
Flying monkeys are a common sight in the wild, but they are also a threat to humans.
|
||||
|
||||
The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.
|
||||
|
||||
"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
|
||||
|
||||
```
|
275
candle-examples/examples/based/main.rs
Normal file
275
candle-examples/examples/based/main.rs
Normal file
@ -0,0 +1,275 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::based::Model;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "360m")]
|
||||
W360m,
|
||||
#[value(name = "1b")]
|
||||
W1b,
|
||||
#[value(name = "1b-50b")]
|
||||
W1b50b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/1")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
#[arg(long, default_value = "360m")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => match args.which {
|
||||
Which::W360m => "hazyresearch/based-360m".to_string(),
|
||||
Which::W1b => "hazyresearch/based-1b".to_string(),
|
||||
Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let config_file = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => vec![repo.get("model.safetensors")?],
|
||||
};
|
||||
|
||||
let repo = api.model("openai-community/gpt2".to_string());
|
||||
let tokenizer_file = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
|
||||
let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
if args.which == Which::W1b50b {
|
||||
vb = vb.pp("model");
|
||||
};
|
||||
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -7,7 +7,7 @@ quantization.
|
||||
## Running one example
|
||||
|
||||
```bash
|
||||
cargo run --example encodec --features symphonia --release -- code-to-audio \
|
||||
cargo run --example encodec --features encodec --release -- code-to-audio \
|
||||
candle-examples/examples/encodec/jfk-codes.safetensors \
|
||||
jfk.wav
|
||||
```
|
||||
|
Binary file not shown.
20
candle-examples/examples/fastvit/README.md
Normal file
20
candle-examples/examples/fastvit/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-fastvit
|
||||
|
||||
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
|
||||
This candle implementation uses a pre-trained FastViT network for inference. The
|
||||
classification head has been trained on the ImageNet dataset and returns the
|
||||
probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12
|
||||
|
||||
loaded image Tensor[dims 3, 256, 256; f32]
|
||||
model built
|
||||
mountain bike, all-terrain bike, off-roader: 52.67%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
|
||||
unicycle, monocycle : 3.46%
|
||||
maillot : 1.32%
|
||||
crash helmet : 1.28%
|
||||
```
|
102
candle-examples/examples/fastvit/main.rs
Normal file
102
candle-examples/examples/fastvit/main.rs
Normal file
@ -0,0 +1,102 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::fastvit;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
T8,
|
||||
T12,
|
||||
S12,
|
||||
SA12,
|
||||
SA24,
|
||||
SA36,
|
||||
MA36,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::T8 => "t8",
|
||||
Self::T12 => "t12",
|
||||
Self::S12 => "s12",
|
||||
Self::SA12 => "sa12",
|
||||
Self::SA24 => "sa24",
|
||||
Self::SA36 => "sa36",
|
||||
Self::MA36 => "ma36",
|
||||
};
|
||||
format!("timm/fastvit_{}.apple_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> fastvit::Config {
|
||||
match self {
|
||||
Self::T8 => fastvit::Config::t8(),
|
||||
Self::T12 => fastvit::Config::t12(),
|
||||
Self::S12 => fastvit::Config::s12(),
|
||||
Self::SA12 => fastvit::Config::sa12(),
|
||||
Self::SA24 => fastvit::Config::sa24(),
|
||||
Self::SA36 => fastvit::Config::sa36(),
|
||||
Self::MA36 => fastvit::Config::ma36(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::S12)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = fastvit::fastvit(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -147,8 +147,8 @@ fn run(args: Args) -> Result<()> {
|
||||
println!("CLIP\n{clip_emb}");
|
||||
let img = {
|
||||
let model_file = match model {
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.sft")?,
|
||||
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
|
||||
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
|
||||
};
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
@ -189,7 +189,7 @@ fn run(args: Args) -> Result<()> {
|
||||
println!("latent img\n{img}");
|
||||
|
||||
let img = {
|
||||
let model_file = bf_repo.get("ae.sft")?;
|
||||
let model_file = bf_repo.get("ae.safetensors")?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = match model {
|
||||
Model::Dev => flux::autoencoder::Config::dev(),
|
||||
|
6
candle-examples/examples/flux/t5_tokenizer.py
Normal file
6
candle-examples/examples/flux/t5_tokenizer.py
Normal file
@ -0,0 +1,6 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
BASE_MODEL = "google/t5-v1_1-xxl"
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
||||
# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json
|
||||
tokenizer.save_pretrained("/tmp/tokenizer/")
|
@ -1,27 +1,27 @@
|
||||
# candle-gemma: 2b and 7b LLMs from Google DeepMind
|
||||
|
||||
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
|
||||
models published by Google Deepmind with a 2b and a 7b variant.
|
||||
|
||||
In order to use the example below, you have to accept the license on the
|
||||
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
|
||||
your access token via the [HuggingFace cli login
|
||||
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
|
||||
models published by Google Deepmind with a 2b and a 7b variant for the first
|
||||
version, and a 2b and a 9b variant for v2.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
|
||||
fn count_primes(max_n: usize) -> usize {
|
||||
let mut primes = vec![true; max_n];
|
||||
for i in 2..=max_n {
|
||||
if primes[i] {
|
||||
for j in i * i..max_n {
|
||||
primes[j] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
primes.len()
|
||||
}
|
||||
$ cargo run --example gemma --features cuda -r -- \
|
||||
--prompt "Here is a proof that square root of 2 is not rational: "
|
||||
|
||||
Here is a proof that square root of 2 is not rational:
|
||||
|
||||
Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
|
||||
|
||||
(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
|
||||
```
|
||||
|
||||
## Access restrictions
|
||||
|
||||
In order to use the v1 examples, you have to accept the license on the
|
||||
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
|
||||
your access token via the [HuggingFace cli login
|
||||
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
|
||||
|
||||
|
||||
|
@ -7,7 +7,8 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::gemma::{Config, Model};
|
||||
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
|
||||
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
@ -38,6 +39,46 @@ enum Which {
|
||||
CodeInstruct2B,
|
||||
#[value(name = "code-7b-it")]
|
||||
CodeInstruct7B,
|
||||
#[value(name = "2-2b")]
|
||||
BaseV2_2B,
|
||||
#[value(name = "2-2b-it")]
|
||||
InstructV2_2B,
|
||||
#[value(name = "2-9b")]
|
||||
BaseV2_9B,
|
||||
#[value(name = "2-9b-it")]
|
||||
InstructV2_9B,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_v1(&self) -> bool {
|
||||
match self {
|
||||
Self::Base2B
|
||||
| Self::Base7B
|
||||
| Self::Instruct2B
|
||||
| Self::Instruct7B
|
||||
| Self::InstructV1_1_2B
|
||||
| Self::InstructV1_1_7B
|
||||
| Self::CodeBase2B
|
||||
| Self::CodeBase7B
|
||||
| Self::CodeInstruct2B
|
||||
| Self::CodeInstruct7B => true,
|
||||
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum Model {
|
||||
V1(Model1),
|
||||
V2(Model2),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::V1(m) => m.forward(input_ids, pos),
|
||||
Self::V2(m) => m.forward(input_ids, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
@ -191,7 +232,7 @@ struct Args {
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "2b")]
|
||||
#[arg(long, default_value = "2-2b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
@ -239,6 +280,10 @@ fn main() -> Result<()> {
|
||||
Which::CodeBase7B => "google/codegemma-7b".to_string(),
|
||||
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
|
||||
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
|
||||
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
|
||||
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
|
||||
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
|
||||
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
|
||||
},
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
@ -263,7 +308,6 @@ fn main() -> Result<()> {
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
@ -273,7 +317,15 @@ fn main() -> Result<()> {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(args.use_flash_attn, &config, vb)?;
|
||||
let model = if args.which.is_v1() {
|
||||
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model1::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V1(model)
|
||||
} else {
|
||||
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
let model = Model2::new(args.use_flash_attn, &config, vb)?;
|
||||
Model::V2(model)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
77
candle-examples/examples/glm4/README.org
Normal file
77
candle-examples/examples/glm4/README.org
Normal file
@ -0,0 +1,77 @@
|
||||
* GLM4
|
||||
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
|
||||
|
||||
- [[https://github.com/THUDM/GLM4][Github]]
|
||||
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]
|
||||
|
||||
** Running with ~cuda~
|
||||
|
||||
#+begin_src shell
|
||||
cargo run --example glm4 --release --features cuda
|
||||
#+end_src
|
||||
|
||||
** Running with ~cpu~
|
||||
#+begin_src shell
|
||||
cargo run --example glm4 --release -- --cpu
|
||||
#+end_src
|
||||
|
||||
** Output Example
|
||||
#+begin_src shell
|
||||
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
|
||||
Finished release [optimized] target(s) in 0.24s
|
||||
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
|
||||
avx: true, neon: false, simd128: false, f16c: true
|
||||
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
|
||||
cache path .
|
||||
retrieved the files in 6.88963ms
|
||||
loaded the model in 6.113752297s
|
||||
starting the inference loop
|
||||
[欢迎使用GLM-4,请输入prompt]
|
||||
请你告诉我什么是FFT
|
||||
266 tokens generated (34.50 token/s)
|
||||
Result:
|
||||
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。
|
||||
|
||||
具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
|
||||
|
||||
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
# 创建一个时域信号
|
||||
t = np.linspace(0, 1, num=100)
|
||||
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
|
||||
|
||||
# 对该信号做FFT变换,并计算其幅值谱
|
||||
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
|
||||
|
||||
```
|
||||
|
||||
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
|
||||
#+end_src
|
||||
|
||||
This example will read prompt from stdin
|
||||
|
||||
* Citation
|
||||
#+begin_src
|
||||
@misc{glm2024chatglm,
|
||||
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
|
||||
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
|
||||
year={2024},
|
||||
eprint={2406.12793},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
|
||||
}
|
||||
#+end_src
|
||||
|
||||
#+begin_src
|
||||
@misc{wang2023cogvlm,
|
||||
title={CogVLM: Visual Expert for Pretrained Language Models},
|
||||
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
|
||||
year={2023},
|
||||
eprint={2311.03079},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
#+end_src
|
255
candle-examples/examples/glm4/main.rs
Normal file
255
candle-examples/examples/glm4/main.rs
Normal file
@ -0,0 +1,255 @@
|
||||
use candle_transformers::models::glm4::*;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
verbose_prompt: bool,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
verbose_prompt,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
|
||||
use std::io::BufRead;
|
||||
use std::io::BufReader;
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
println!("[欢迎使用GLM-4,请输入prompt]");
|
||||
let stdin = std::io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
for line in reader.lines() {
|
||||
let line = line.expect("Failed to read line");
|
||||
|
||||
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
|
||||
if tokens.is_empty() {
|
||||
panic!("Empty prompts are not supported in the chatglm model.")
|
||||
}
|
||||
if self.verbose_prompt {
|
||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
||||
println!("{id:7} -> '{token}'");
|
||||
}
|
||||
}
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => panic!("cannot find the endoftext token"),
|
||||
};
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
|
||||
std::io::stdout().flush().expect("output flush error");
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
||||
let mut count = 0;
|
||||
let mut result = vec![];
|
||||
for index in 0..sample_len {
|
||||
count += 1;
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
let token = self
|
||||
.tokenizer
|
||||
.decode(&[next_token], true)
|
||||
.expect("Token error");
|
||||
if self.verbose_prompt {
|
||||
println!(
|
||||
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
|
||||
count, next_token, token
|
||||
);
|
||||
}
|
||||
result.push(token);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
println!("Result:");
|
||||
for tokens in result {
|
||||
print!("{tokens}");
|
||||
}
|
||||
self.model.reset_kv_cache(); // clean the cache
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(name = "cache", short, long, default_value = ".")]
|
||||
cache_path: String,
|
||||
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 8192)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.2)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.6),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
println!("cache path {}", args.cache_path);
|
||||
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
|
||||
.build()
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => "THUDM/glm-4-9b".to_string(),
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => api
|
||||
.model("THUDM/codegeex4-all-9b".to_string())
|
||||
.get("tokenizer.json")
|
||||
.map_err(anyhow::Error::msg)?,
|
||||
};
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = Config::glm4();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
args.verbose_prompt,
|
||||
&device,
|
||||
dtype,
|
||||
);
|
||||
pipeline.run(args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
20
candle-examples/examples/granite/README.md
Normal file
20
candle-examples/examples/granite/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-granite LLMs from IBM Research
|
||||
|
||||
[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
|
||||
--prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"
|
||||
|
||||
Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.
|
||||
|
||||
In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
There are two different modalities for the Granite family models: Language and Code.
|
||||
|
||||
### Granite for language
|
||||
1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)
|
251
candle-examples/examples/granite/main.rs
Normal file
251
candle-examples/examples/granite/main.rs
Normal file
@ -0,0 +1,251 @@
|
||||
// An implementation of different Granite models https://www.ibm.com/granite
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
use candle_transformers::models::granite as model;
|
||||
use model::{Granite, GraniteConfig};
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum GraniteModel {
|
||||
Granite7bInstruct,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Disable the key-value cache.
|
||||
#[arg(long)]
|
||||
no_kv_cache: bool,
|
||||
|
||||
/// The initial prompt.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Use different dtype than f16
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "granite7b-instruct")]
|
||||
model_type: GraniteModel,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 128)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = match args.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (granite, tokenizer_filename, mut cache, config) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
|
||||
GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let filenames = match args.model_type {
|
||||
GraniteModel::Granite7bInstruct => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
};
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(
|
||||
Granite::load(vb, &config)?,
|
||||
tokenizer_filename,
|
||||
cache,
|
||||
config,
|
||||
)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = config.eos_token_id.or_else(|| {
|
||||
tokenizer
|
||||
.token_to_id(EOS_TOKEN)
|
||||
.map(model::GraniteEosToks::Single)
|
||||
});
|
||||
|
||||
let default_prompt = match args.model_type {
|
||||
GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
|
||||
};
|
||||
|
||||
let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||
|
||||
println!("Starting the inference loop:");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let mut start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
let use_cache_kv = cache.use_kv_cache;
|
||||
|
||||
(0..args.sample_len)
|
||||
.inspect(|index| {
|
||||
if *index == 1 {
|
||||
start_gen = Instant::now();
|
||||
}
|
||||
})
|
||||
.try_for_each(|index| -> Result<()> {
|
||||
let (context_size, context_index) = if use_cache_kv && index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = granite
|
||||
.forward(&input, context_index, &mut cache)?
|
||||
.squeeze(0)?;
|
||||
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
index_pos += ctxt.len();
|
||||
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
token_generated += 1;
|
||||
tokens.push(next_token);
|
||||
|
||||
if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
|
||||
if next_token == eos_tok_id {
|
||||
return Err(E::msg("EOS token found"));
|
||||
}
|
||||
} else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
|
||||
if eos_ids.contains(&next_token) {
|
||||
return Err(E::msg("EOS token found"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.unwrap_or(());
|
||||
|
||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n\n{} tokens generated ({} token/s)\n",
|
||||
token_generated,
|
||||
(token_generated - 1) as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -14,6 +14,7 @@ use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::llama::LlamaEosToks;
|
||||
use cudarc::driver::safe::CudaDevice;
|
||||
use cudarc::nccl::safe::{Comm, Id};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@ -219,9 +220,16 @@ fn main() -> Result<()> {
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
if Some(next_token) == config.eos_token_id {
|
||||
break;
|
||||
match config.eos_token_id {
|
||||
Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
|
||||
break;
|
||||
}
|
||||
Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
|
||||
break;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
if rank == 0 {
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
|
@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
|
||||
else:
|
||||
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
|
||||
|
||||
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
|
||||
if add_prefix_space:
|
||||
prepend_scheme = "always"
|
||||
if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
|
||||
prepend_scheme = "first"
|
||||
else:
|
||||
prepend_scheme = "never"
|
||||
return prepend_scheme
|
||||
|
||||
class SentencePieceExtractor:
|
||||
"""
|
||||
@ -519,13 +527,15 @@ class SpmConverter(Converter):
|
||||
)
|
||||
|
||||
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
|
||||
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
|
||||
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
|
||||
|
||||
def post_processor(self):
|
||||
return None
|
||||
|
||||
def decoder(self, replacement, add_prefix_space):
|
||||
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
|
||||
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
|
||||
return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
|
||||
|
||||
def converted(self) -> Tokenizer:
|
||||
tokenizer = self.tokenizer(self.proto)
|
||||
@ -636,7 +646,8 @@ class DebertaV2Converter(SpmConverter):
|
||||
list_pretokenizers = []
|
||||
if self.original_tokenizer.split_by_punct:
|
||||
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
|
||||
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
|
||||
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
|
||||
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
|
||||
return pre_tokenizers.Sequence(list_pretokenizers)
|
||||
|
||||
def normalizer(self, proto):
|
||||
@ -929,10 +940,11 @@ class PegasusConverter(SpmConverter):
|
||||
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
|
||||
|
||||
def pre_tokenizer(self, replacement, add_prefix_space):
|
||||
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
|
||||
return pre_tokenizers.Sequence(
|
||||
[
|
||||
pre_tokenizers.WhitespaceSplit(),
|
||||
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
|
||||
pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
|
||||
]
|
||||
)
|
||||
|
||||
|
20
candle-examples/examples/mimi/README.md
Normal file
20
candle-examples/examples/mimi/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-mimi
|
||||
|
||||
[Mimi](https://huggingface.co/kyutai/mimi) is a state of the art audio
|
||||
compression model using an encoder/decoder architecture with residual vector
|
||||
quantization. The candle implementation supports streaming meaning that it's
|
||||
possible to encode or decode a stream of audio tokens on the flight to provide
|
||||
low latency interaction with an audio model.
|
||||
|
||||
## Running one example
|
||||
|
||||
Generating some audio tokens from an audio files.
|
||||
```bash
|
||||
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
|
||||
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
|
||||
```
|
||||
|
||||
And decoding the audio tokens back into a sound file.
|
||||
```bash
|
||||
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
|
||||
```
|
275
candle-examples/examples/mimi/audio_io.rs
Normal file
275
candle-examples/examples/mimi/audio_io.rs
Normal file
@ -0,0 +1,275 @@
|
||||
#![allow(unused)]
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const SAMPLE_RATE: usize = 24_000;
|
||||
|
||||
pub(crate) struct AudioOutputData_ {
|
||||
resampled_data: std::collections::VecDeque<f32>,
|
||||
resampler: rubato::FastFixedIn<f32>,
|
||||
output_buffer: Vec<f32>,
|
||||
input_buffer: Vec<f32>,
|
||||
input_len: usize,
|
||||
}
|
||||
|
||||
impl AudioOutputData_ {
|
||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||
let resampler = rubato::FastFixedIn::new(
|
||||
resample_ratio,
|
||||
f64::max(resample_ratio, 1.0),
|
||||
rubato::PolynomialDegree::Septic,
|
||||
1024,
|
||||
1,
|
||||
)?;
|
||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||
Ok(Self {
|
||||
resampled_data,
|
||||
resampler,
|
||||
input_buffer,
|
||||
output_buffer,
|
||||
input_len: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
use rubato::Resampler;
|
||||
self.output_buffer.fill(0.);
|
||||
self.input_buffer.fill(0.);
|
||||
self.resampler.reset();
|
||||
self.resampled_data.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||
while let Some(elem) = self.resampled_data.pop_back() {
|
||||
data.push(elem);
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.resampled_data.is_empty()
|
||||
}
|
||||
|
||||
// Assumes that the input buffer is large enough.
|
||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||
self.input_len += samples.len()
|
||||
}
|
||||
|
||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pos_in = 0;
|
||||
loop {
|
||||
let rem = self.input_buffer.len() - self.input_len;
|
||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||
pos_in = pos_end;
|
||||
if self.input_len < self.input_buffer.len() {
|
||||
break;
|
||||
}
|
||||
let (_, out_len) = self.resampler.process_into_buffer(
|
||||
&[&self.input_buffer],
|
||||
&mut [&mut self.output_buffer],
|
||||
None,
|
||||
)?;
|
||||
for &elem in self.output_buffer[..out_len].iter() {
|
||||
self.resampled_data.push_front(elem)
|
||||
}
|
||||
self.input_len = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||
|
||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio output stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_output_device()
|
||||
.context("no output device available")?;
|
||||
let mut supported_configs_range = device.supported_output_configs()?;
|
||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||
None => device
|
||||
.supported_output_configs()?
|
||||
.next()
|
||||
.context("no audio output available")?,
|
||||
Some(config_range) => config_range,
|
||||
};
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
let channels = config.channels as usize;
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
SAMPLE_RATE,
|
||||
config.sample_rate.0 as usize,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
data.fill(0.);
|
||||
let mut ad = ad.lock().unwrap();
|
||||
let mut last_elem = 0f32;
|
||||
for (idx, elem) in data.iter_mut().enumerate() {
|
||||
if idx % channels == 0 {
|
||||
match ad.resampled_data.pop_back() {
|
||||
None => break,
|
||||
Some(v) => {
|
||||
last_elem = v;
|
||||
*elem = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*elem = last_elem
|
||||
}
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio input stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.context("no input device available")?;
|
||||
let mut supported_configs_range = device.supported_input_configs()?;
|
||||
let config_range = supported_configs_range
|
||||
.find(|c| c.channels() == 1)
|
||||
.context("no audio input available")?;
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
config.sample_rate.0 as usize,
|
||||
SAMPLE_RATE,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mut ad = ad.lock().unwrap();
|
||||
if let Err(err) = ad.push_samples(data) {
|
||||
eprintln!("error processing audio input {err:?}")
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) =
|
||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||
Some(&[&pcm_in[pos_in..]]),
|
||||
&mut output_buffer,
|
||||
None,
|
||||
)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
131
candle-examples/examples/mimi/main.rs
Normal file
131
candle-examples/examples/mimi/main.rs
Normal file
@ -0,0 +1,131 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::mimi::{Config, Model};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
mod audio_io;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some mimi tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("kyutai/mimi".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let config = Config::v0_1(None);
|
||||
let mut model = Model::new(config, vb)?;
|
||||
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
codes.get("codes").expect("no codes in input file").clone()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let pcm = if args.in_file == "-" {
|
||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||
let mut pcms = vec![];
|
||||
let stdin = std::thread::spawn(|| {
|
||||
let mut s = String::new();
|
||||
std::io::stdin().read_line(&mut s)
|
||||
});
|
||||
while !stdin.is_finished() {
|
||||
let input = input_audio.lock().unwrap().take_all();
|
||||
if input.is_empty() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
pcms.push(input)
|
||||
}
|
||||
drop(stream);
|
||||
pcms.concat()
|
||||
} else {
|
||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||
if sample_rate != 24_000 {
|
||||
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
|
||||
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
|
||||
} else {
|
||||
pcm
|
||||
}
|
||||
};
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
codes.save_safetensors("codes", &args.out_file)?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
if args.out_file == "-" {
|
||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||
{
|
||||
let mut ad = ad.lock().unwrap();
|
||||
ad.push_samples(&pcm)?;
|
||||
}
|
||||
loop {
|
||||
let ad = ad.lock().unwrap();
|
||||
if ad.is_empty() {
|
||||
break;
|
||||
}
|
||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||
// playing (the callback doesn't seem to be called anymore).
|
||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
}
|
||||
drop(stream)
|
||||
} else {
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
28
candle-examples/examples/mobileclip/README.md
Normal file
28
candle-examples/examples/mobileclip/README.md
Normal file
@ -0,0 +1,28 @@
|
||||
# candle-mobileclip
|
||||
|
||||
MobileCLIP is family of efficient CLIP-like models using FastViT-based image encoders.
|
||||
|
||||
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
|
||||
|
||||
|
||||
## Running on an example on cpu
|
||||
|
||||
```
|
||||
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
|
||||
|
||||
softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
|
||||
|
||||
|
||||
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||
|
||||
Probability: 0.0025% Text: a cycling race
|
||||
Probability: 0.0004% Text: a photo of two cats
|
||||
Probability: 99.9971% Text: a robot holding a candle
|
||||
|
||||
|
||||
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
Probability: 99.9974% Text: a cycling race
|
||||
Probability: 0.0024% Text: a photo of two cats
|
||||
Probability: 0.0002% Text: a robot holding a candle
|
||||
```
|
192
candle-examples/examples/mobileclip/main.rs
Normal file
192
candle-examples/examples/mobileclip/main.rs
Normal file
@ -0,0 +1,192 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use candle_transformers::models::mobileclip;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
S1,
|
||||
S2,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_name(&self) -> String {
|
||||
let name = match self {
|
||||
Self::S1 => "S1",
|
||||
Self::S2 => "S2",
|
||||
};
|
||||
format!("apple/MobileCLIP-{}-OpenCLIP", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> mobileclip::MobileClipConfig {
|
||||
match self {
|
||||
Self::S1 => mobileclip::MobileClipConfig::s1(),
|
||||
Self::S2 => mobileclip::MobileClipConfig::s2(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
images: Option<Vec<String>>,
|
||||
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Use the pytorch weights rather than the safetensors ones
|
||||
#[arg(long)]
|
||||
use_pth: bool,
|
||||
|
||||
#[arg(long, use_value_delimiter = true)]
|
||||
sequences: Option<Vec<String>>,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::S1)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
fn load_images<T: AsRef<std::path::Path>>(
|
||||
paths: &Vec<T>,
|
||||
image_size: usize,
|
||||
) -> anyhow::Result<Tensor> {
|
||||
let mut images = vec![];
|
||||
|
||||
for path in paths {
|
||||
let tensor = candle_examples::imagenet::load_image_with_std_mean(
|
||||
path,
|
||||
image_size,
|
||||
&[0.0, 0.0, 0.0],
|
||||
&[1.0, 1.0, 1.0],
|
||||
)?;
|
||||
images.push(tensor);
|
||||
}
|
||||
|
||||
let images = Tensor::stack(&images, 0)?;
|
||||
|
||||
Ok(images)
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let model_name = args.which.model_name();
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
|
||||
let model_file = if args.use_pth {
|
||||
api.get("open_clip_pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("open_clip_model.safetensors")?
|
||||
};
|
||||
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
|
||||
let config = &args.which.config();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let vec_imgs = match args.images {
|
||||
Some(imgs) => imgs,
|
||||
None => vec![
|
||||
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
|
||||
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
|
||||
|
||||
let vb = if args.use_pth {
|
||||
VarBuilder::from_pth(&model_file, DType::F32, &device)?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
|
||||
};
|
||||
|
||||
let model = mobileclip::MobileClipModel::new(vb, config)?;
|
||||
|
||||
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
|
||||
|
||||
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
|
||||
|
||||
let softmax_image = softmax(&logits_per_image, 1)?;
|
||||
|
||||
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
println!("softmax_image_vec: {:?}", softmax_image_vec);
|
||||
|
||||
let probability_vec = softmax_image_vec
|
||||
.iter()
|
||||
.map(|v| v * 100.0)
|
||||
.collect::<Vec<f32>>();
|
||||
|
||||
let probability_per_image = probability_vec.len() / vec_imgs.len();
|
||||
|
||||
for (i, img) in vec_imgs.iter().enumerate() {
|
||||
let start = i * probability_per_image;
|
||||
let end = start + probability_per_image;
|
||||
let prob = &probability_vec[start..end];
|
||||
println!("\n\nResults for image: {}\n", img);
|
||||
|
||||
for (i, p) in prob.iter().enumerate() {
|
||||
println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn tokenize_sequences(
|
||||
sequences: Option<Vec<String>>,
|
||||
tokenizer: &Tokenizer,
|
||||
device: &Device,
|
||||
) -> anyhow::Result<(Tensor, Vec<String>)> {
|
||||
// let pad_id = *tokenizer
|
||||
// .get_vocab(true)
|
||||
// .get("<|endoftext|>")
|
||||
// .ok_or(E::msg("No pad token"))?;
|
||||
|
||||
// The model does not work well if the text is padded using the <|endoftext|> token, using 0
|
||||
// as the original OpenCLIP code.
|
||||
let pad_id = 0;
|
||||
|
||||
let vec_seq = match sequences {
|
||||
Some(seq) => seq,
|
||||
None => vec![
|
||||
"a cycling race".to_string(),
|
||||
"a photo of two cats".to_string(),
|
||||
"a robot holding a candle".to_string(),
|
||||
],
|
||||
};
|
||||
|
||||
let mut tokens = vec![];
|
||||
|
||||
for seq in vec_seq.clone() {
|
||||
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
|
||||
tokens.push(encoding.get_ids().to_vec());
|
||||
}
|
||||
|
||||
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
|
||||
// Pad the sequences to have the same length
|
||||
for token_vec in tokens.iter_mut() {
|
||||
let len_diff = max_len - token_vec.len();
|
||||
if len_diff > 0 {
|
||||
token_vec.extend(vec![pad_id; len_diff]);
|
||||
}
|
||||
}
|
||||
|
||||
let input_ids = Tensor::new(tokens, device)?;
|
||||
|
||||
Ok((input_ids, vec_seq))
|
||||
}
|
@ -72,8 +72,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
|
||||
.to_device(&device)?;
|
||||
let image =
|
||||
candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
|
||||
.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -284,11 +284,11 @@ impl MusicgenDecoder {
|
||||
};
|
||||
let embed_dim = cfg.vocab_size + 1;
|
||||
let embed_tokens = (0..cfg.num_codebooks)
|
||||
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
|
||||
.map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}"))))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
|
||||
let layers = (0..cfg.num_hidden_layers)
|
||||
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
|
||||
.map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
|
||||
let h = cfg.hidden_size;
|
||||
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
||||
let lm_heads = (0..cfg.num_codebooks)
|
||||
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
|
||||
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}"))))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
decoder,
|
||||
|
23
candle-examples/examples/parler-tts/README.md
Normal file
23
candle-examples/examples/parler-tts/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# candle-parler-tts
|
||||
|
||||
[Parler-TTS](https://huggingface.co/parler-tts/parler-tts-large-v1) is a large
|
||||
text-to-speech model with 2.2B parameters trained on ~45K hours of audio data.
|
||||
The voice can be controlled by a text prompt.
|
||||
|
||||
## Run an example
|
||||
|
||||
```bash
|
||||
cargo run --example parler-tts -r -- \
|
||||
--prompt "Hey, how are you doing today?"
|
||||
```
|
||||
|
||||
In order to specify some prompt for the voice, use the `--description` argument.
|
||||
```bash
|
||||
cargo run --example parler-tts -r -- \
|
||||
--prompt "Hey, how are you doing today?" \
|
||||
--description "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
|
||||
```
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/1b16aeac-70a3-4803-8589-4563279bba33
|
||||
|
BIN
candle-examples/examples/parler-tts/hello.mp4
Normal file
BIN
candle-examples/examples/parler-tts/hello.mp4
Normal file
Binary file not shown.
206
candle-examples/examples/parler-tts/main.rs
Normal file
206
candle-examples/examples/parler-tts/main.rs
Normal file
@ -0,0 +1,206 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::parler_tts::{Config, Model};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Display the token for the specified prompt.
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
#[arg(long, default_value = "Hey, how are you doing today?")]
|
||||
prompt: String,
|
||||
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
|
||||
)]
|
||||
description: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.0)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
seed: u64,
|
||||
|
||||
#[arg(long, default_value_t = 5000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Use f16 precision for all the computations rather than f32.
|
||||
#[arg(long)]
|
||||
f16: bool,
|
||||
|
||||
#[arg(long)]
|
||||
model_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long, default_value_t = 512)]
|
||||
max_steps: usize,
|
||||
|
||||
/// The output wav file.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
|
||||
#[arg(long, default_value = "large-v1")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "large-v1")]
|
||||
LargeV1,
|
||||
#[value(name = "mini-v1")]
|
||||
MiniV1,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => match args.which {
|
||||
Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(),
|
||||
Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(),
|
||||
},
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(r) => r,
|
||||
None => "main".to_string(),
|
||||
};
|
||||
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||
model_id,
|
||||
hf_hub::RepoType::Model,
|
||||
revision,
|
||||
));
|
||||
let model_files = match args.model_file {
|
||||
Some(m) => vec![m.into()],
|
||||
None => match args.which {
|
||||
Which::MiniV1 => vec![repo.get("model.safetensors")?],
|
||||
Which::LargeV1 => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
},
|
||||
};
|
||||
let config = match args.config_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let tokenizer = match args.tokenizer_file {
|
||||
Some(m) => m.into(),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||
let mut model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let description_tokens = tokenizer
|
||||
.encode(args.description, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
|
||||
let prompt_tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
|
||||
let lp = candle_transformers::generation::LogitsProcessor::new(
|
||||
args.seed,
|
||||
Some(args.temperature),
|
||||
args.top_p,
|
||||
);
|
||||
println!("starting generation...");
|
||||
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
|
||||
println!("generated codes\n{codes}");
|
||||
let codes = codes.to_dtype(DType::I64)?;
|
||||
codes.save_safetensors("codes", "out.safetensors")?;
|
||||
let codes = codes.unsqueeze(0)?;
|
||||
let pcm = model
|
||||
.audio_encoder
|
||||
.decode_codes(&codes.to_device(&device)?)?;
|
||||
println!("{pcm}");
|
||||
let pcm = pcm.i((0, 0))?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
|
||||
|
||||
Ok(())
|
||||
}
|
12
candle-examples/examples/silero-vad/README.md
Normal file
12
candle-examples/examples/silero-vad/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
# silero-vad: Voice Activity Detection
|
||||
|
||||
[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.
|
||||
|
||||
This example uses the models available in the hugging face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad).
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||
```
|
||||
|
199
candle-examples/examples/silero-vad/main.rs
Normal file
199
candle-examples/examples/silero-vad/main.rs
Normal file
@ -0,0 +1,199 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "silero")]
|
||||
Silero,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum SampleRate {
|
||||
#[value(name = "8000")]
|
||||
Sr8k,
|
||||
#[value(name = "16000")]
|
||||
Sr16k,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
input: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
sample_rate: SampleRate,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "silero")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
/// an iterator which reads consecutive frames of le i16 values from a reader
|
||||
struct I16Frames<R> {
|
||||
rdr: R,
|
||||
buf: Box<[u8]>,
|
||||
len: usize,
|
||||
eof: bool,
|
||||
}
|
||||
impl<R> I16Frames<R> {
|
||||
fn new(rdr: R, frame_size: usize) -> Self {
|
||||
I16Frames {
|
||||
rdr,
|
||||
buf: vec![0; frame_size * std::mem::size_of::<i16>()].into_boxed_slice(),
|
||||
len: 0,
|
||||
eof: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<R: std::io::Read> Iterator for I16Frames<R> {
|
||||
type Item = std::io::Result<Vec<f32>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.eof {
|
||||
return None;
|
||||
}
|
||||
self.len += match self.rdr.read(&mut self.buf[self.len..]) {
|
||||
Ok(0) => {
|
||||
self.eof = true;
|
||||
0
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
if self.eof || self.len == self.buf.len() {
|
||||
let buf = self.buf[..self.len]
|
||||
.chunks(2)
|
||||
.map(|bs| match bs {
|
||||
[a, b] => i16::from_le_bytes([*a, *b]),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.map(|i| i as f32 / i16::MAX as f32)
|
||||
.collect();
|
||||
self.len = 0;
|
||||
Some(Ok(buf))
|
||||
} else {
|
||||
self.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => std::path::PathBuf::from(model_id),
|
||||
None => match args.which {
|
||||
Which::Silero => hf_hub::api::sync::Api::new()?
|
||||
.model("onnx-community/silero-vad".into())
|
||||
.get("onnx/model.onnx")?,
|
||||
// TODO: candle-onnx doesn't support Int8 dtype
|
||||
// Which::SileroQuantized => hf_hub::api::sync::Api::new()?
|
||||
// .model("onnx-community/silero-vad".into())
|
||||
// .get("onnx/model_quantized.onnx")?,
|
||||
},
|
||||
};
|
||||
let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate {
|
||||
SampleRate::Sr8k => (8000, 256, 32),
|
||||
SampleRate::Sr16k => (16000, 512, 64),
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = candle_onnx::read_file(model_id)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
struct State {
|
||||
frame_size: usize,
|
||||
sample_rate: Tensor,
|
||||
state: Tensor,
|
||||
context: Tensor,
|
||||
}
|
||||
|
||||
let mut state = State {
|
||||
frame_size,
|
||||
sample_rate: Tensor::new(sample_rate, &device)?,
|
||||
state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
|
||||
context: Tensor::zeros((1, context_size), DType::F32, &device)?,
|
||||
};
|
||||
let mut res = vec![];
|
||||
for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) {
|
||||
let chunk = chunk.unwrap();
|
||||
if chunk.len() < state.frame_size {
|
||||
continue;
|
||||
}
|
||||
let next_context = Tensor::from_slice(
|
||||
&chunk[state.frame_size - context_size..],
|
||||
(1, context_size),
|
||||
&device,
|
||||
)?;
|
||||
let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?;
|
||||
let chunk = Tensor::cat(&[&state.context, &chunk], 1)?;
|
||||
let inputs = std::collections::HashMap::from_iter([
|
||||
("input".to_string(), chunk),
|
||||
("sr".to_string(), state.sample_rate.clone()),
|
||||
("state".to_string(), state.state.clone()),
|
||||
]);
|
||||
let out = candle_onnx::simple_eval(&model, inputs).unwrap();
|
||||
let out_names = &model.graph.as_ref().unwrap().output;
|
||||
let output = out.get(&out_names[0].name).unwrap().clone();
|
||||
state.state = out.get(&out_names[1].name).unwrap().clone();
|
||||
assert_eq!(state.state.dims(), &[2, 1, 128]);
|
||||
state.context = next_context;
|
||||
|
||||
let output = output.flatten_all()?.to_vec1::<f32>()?;
|
||||
assert_eq!(output.len(), 1);
|
||||
let output = output[0];
|
||||
println!("vad chunk prediction: {output}");
|
||||
res.push(output);
|
||||
}
|
||||
println!("calculated prediction in {:?}", start.elapsed());
|
||||
|
||||
let res_len = res.len() as f32;
|
||||
let prediction = res.iter().sum::<f32>() / res_len;
|
||||
println!("vad average prediction: {prediction}");
|
||||
Ok(())
|
||||
}
|
@ -123,7 +123,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
|
||||
let (bn, bias) = match b.parameters.get("batch_normalize") {
|
||||
Some(p) if p.parse::<usize>()? != 0 => {
|
||||
let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
|
||||
let bn = batch_norm(filters, 1e-5, vb.pp(format!("batch_norm_{index}")))?;
|
||||
(Some(bn), false)
|
||||
}
|
||||
Some(_) | None => (None, true),
|
||||
@ -135,9 +135,9 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
||||
dilation: 1,
|
||||
};
|
||||
let conv = if bias {
|
||||
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
|
||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||
} else {
|
||||
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
|
||||
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||
};
|
||||
let leaky = match activation {
|
||||
"leaky" => true,
|
||||
|
@ -161,7 +161,7 @@ impl C2f {
|
||||
let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?;
|
||||
let mut bottleneck = Vec::with_capacity(n);
|
||||
for idx in 0..n {
|
||||
let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?;
|
||||
let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?;
|
||||
bottleneck.push(b)
|
||||
}
|
||||
Ok(Self {
|
||||
|
@ -1,23 +1,42 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
/// Loads an image from disk using the image crate at the requested resolution.
|
||||
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
|
||||
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
|
||||
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
|
||||
|
||||
/// Loads an image from disk using the image crate at the requested resolution,
|
||||
/// using the given std and mean parameters.
|
||||
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||
|
||||
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
res: usize,
|
||||
mean: &[f32; 3],
|
||||
std: &[f32; 3],
|
||||
) -> Result<Tensor> {
|
||||
let img = image::ImageReader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
|
||||
.resize_to_fill(
|
||||
res as u32,
|
||||
res as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?;
|
||||
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate at the requested resolution.
|
||||
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
|
||||
load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 224, 224). imagenet normalization is applied.
|
||||
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
@ -17,9 +17,12 @@ thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
||||
[dev-dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
half = { version = "2.3.1", features = [
|
||||
"num-traits",
|
||||
"use-intrinsics",
|
||||
"rand_distr",
|
||||
] }
|
||||
anyhow = "1"
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
|
136
candle-metal-kernels/examples/metal_benchmarks.rs
Normal file
136
candle-metal-kernels/examples/metal_benchmarks.rs
Normal file
@ -0,0 +1,136 @@
|
||||
use anyhow::Result;
|
||||
use candle_metal_kernels::GemmDType;
|
||||
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
|
||||
use clap::{Parser, Subcommand};
|
||||
use half::f16;
|
||||
|
||||
fn run_gemm(f32: bool, n: usize) -> Result<()> {
|
||||
const WARMUP_ITERS: usize = 2;
|
||||
const MIN_DUR: f64 = 4.;
|
||||
|
||||
let device = metal::Device::system_default().unwrap();
|
||||
|
||||
let (b, m, n, k) = (1, n, n, n);
|
||||
let kernels = candle_metal_kernels::Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let options = metal::MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let (lhs, rhs) = if f32 {
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(&lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(&rhs) as u64,
|
||||
options,
|
||||
);
|
||||
(lhs, rhs)
|
||||
} else {
|
||||
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(&lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(&rhs) as u64,
|
||||
options,
|
||||
);
|
||||
(lhs, rhs)
|
||||
};
|
||||
let (dtype, name, sizeof) = if f32 {
|
||||
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
|
||||
} else {
|
||||
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
|
||||
};
|
||||
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
|
||||
|
||||
for mlx in [false, true] {
|
||||
let mut sum_dt = 0f64;
|
||||
let mut iters = 0usize;
|
||||
for idx in 0.. {
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let start_time = std::time::Instant::now();
|
||||
if mlx {
|
||||
candle_metal_kernels::call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
} else {
|
||||
candle_metal_kernels::call_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&lhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&output,
|
||||
)?;
|
||||
}
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
let dt = start_time.elapsed().as_secs_f64();
|
||||
if idx < WARMUP_ITERS {
|
||||
continue;
|
||||
}
|
||||
sum_dt += dt;
|
||||
iters += 1;
|
||||
if sum_dt > MIN_DUR {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
|
||||
let mlx = if mlx { "MLX" } else { "MFA" };
|
||||
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug, Clone)]
|
||||
enum Task {
|
||||
Gemm,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// The benchmark to be run.
|
||||
#[command(subcommand)]
|
||||
task: Task,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.task {
|
||||
Task::Gemm => {
|
||||
for f32 in [false, true] {
|
||||
for n in [512, 1024, 2048, 4096] {
|
||||
run_gemm(f32, n)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -105,7 +105,7 @@ kernel void FN_NAME##_strided( \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
|
||||
} \
|
||||
|
||||
|
||||
|
@ -1,921 +0,0 @@
|
||||
//
|
||||
// GEMM.metal
|
||||
// MetalFlashAttention
|
||||
//
|
||||
// Created by Philip Turner on 6/23/23.
|
||||
//
|
||||
#include <metal_stdlib>
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_EVENT
|
||||
#define __METAL_SIMDGROUP_EVENT
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %struct._simdgroup_event_t = type opaque
|
||||
//
|
||||
struct _simdgroup_event_t;
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, threadgroup void *, const device void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, device void *, const threadgroup void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
threadgroup void *, ulong, ulong, ulong2,
|
||||
const device void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
device void *, ulong, ulong, ulong2,
|
||||
const threadgroup void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: convergent nounwind
|
||||
// declare void
|
||||
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
|
||||
// local_unnamed_addr #3
|
||||
//
|
||||
void __metal_wait_simdgroup_events(
|
||||
int, thread _simdgroup_event_t**)
|
||||
__asm("air.wait_simdgroup_events");
|
||||
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
enum class simdgroup_async_copy_clamp_mode {
|
||||
clamp_to_zero = 0,
|
||||
clamp_to_edge = 1
|
||||
};
|
||||
|
||||
struct simdgroup_event {
|
||||
METAL_FUNC simdgroup_event() thread {}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup T *dst,
|
||||
const device T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
reinterpret_cast<const device void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
device T *dst,
|
||||
const threadgroup T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup T *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device T *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false,
|
||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
ushort(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const device void *>(src),
|
||||
uint(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device T *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup T *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
uint(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
ushort(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
||||
__metal_wait_simdgroup_events(
|
||||
count, reinterpret_cast<thread _simdgroup_event_t**>(events));
|
||||
}
|
||||
|
||||
private:
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
|
||||
//
|
||||
thread _simdgroup_event_t* event;
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
#endif
|
||||
|
||||
// -*- Metal -*-
|
||||
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
|
||||
// Copyright (c) 2023 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
#define __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
// Contains C++ symbols accessible to a developer through automatic code
|
||||
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
|
||||
// Library for consistency with other Metal code.
|
||||
|
||||
#if defined(__HAVE_SIMDGROUP_MATRIX__)
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
template <typename T>
|
||||
struct simdgroup_matrix_storage {
|
||||
typedef vec<T, 64> storage_type;
|
||||
|
||||
storage_type t;
|
||||
|
||||
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
|
||||
return reinterpret_cast<thread vec<T, 2>*>(&t);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage() thread = default;
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
|
||||
*(this->thread_elements()) = thread_elements;
|
||||
}
|
||||
|
||||
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
|
||||
// https://patents.google.com/patent/US11256518B2
|
||||
ushort lane_id = thread_index_in_simdgroup;
|
||||
ushort quad_id = lane_id / 4;
|
||||
|
||||
constexpr ushort QUADRANT_SPAN_M = 4;
|
||||
constexpr ushort THREADS_PER_QUADRANT = 8;
|
||||
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
|
||||
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
|
||||
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
|
||||
|
||||
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
|
||||
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
|
||||
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
|
||||
|
||||
return ushort2(N_in_simd, M_in_simd);
|
||||
}
|
||||
|
||||
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
|
||||
} else {
|
||||
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
|
||||
} else {
|
||||
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: All load and store functions assume the X dimension is divisible by 2.
|
||||
|
||||
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
|
||||
} else {
|
||||
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
|
||||
} else {
|
||||
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
|
||||
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
|
||||
} else {
|
||||
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
|
||||
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
|
||||
} else {
|
||||
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
|
||||
if (!accumulate) {
|
||||
*(thread_elements()) = vec<T, 2>(0);
|
||||
}
|
||||
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
|
||||
}
|
||||
|
||||
// 'bfloat' is 'float' with the lower 16 bits set to garbage (BF15).
|
||||
|
||||
METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
|
||||
thread float2* elements = thread_elements();
|
||||
return reinterpret_cast<thread ushort4*>(elements);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
|
||||
ushort4 output;
|
||||
thread ushort2& elements = thread_elements();
|
||||
output.y = elements[0];
|
||||
output.w = elements[1];
|
||||
return simdgroup_matrix_storage(as_type<float2>(output));
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
|
||||
thread ushort4* elements = thread_elements_bfloat();
|
||||
return simdgroup_matrix_storage(ushort2(elements->y, elements->w));
|
||||
}
|
||||
|
||||
METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
thread_elements_bfloat()->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
|
||||
thread_elements_bfloat()->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
|
||||
} else {
|
||||
thread_elements_bfloat()->zw = *reinterpret_cast<const threadgroup ushort2*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
|
||||
thread_elements_bfloat()->y = thread_elements_bfloat()->z;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).y;
|
||||
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).w;
|
||||
} else {
|
||||
*(thread_elements_bfloat()).z = *(thread_elements_bfloat()).y;
|
||||
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements_bfloat()).zw;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
#endif
|
||||
|
||||
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// MARK: - Function Constants
|
||||
|
||||
// Dimensions of each matrix.
|
||||
constant uint M [[function_constant(0)]];
|
||||
constant uint N [[function_constant(1)]];
|
||||
constant uint K [[function_constant(2)]];
|
||||
|
||||
// Whether each matrix is transposed.
|
||||
constant bool A_trans [[function_constant(10)]];
|
||||
constant bool B_trans [[function_constant(11)]];
|
||||
constant bool D_trans [[function_constant(13)]];
|
||||
constant uint A_leading_dim = A_trans ? M : K;
|
||||
constant uint B_leading_dim = B_trans ? K : N;
|
||||
|
||||
// Alpha and beta constants from BLAS.
|
||||
constant float alpha [[function_constant(20)]];
|
||||
constant float beta [[function_constant(21)]];
|
||||
|
||||
constant bool batched [[function_constant(100)]];
|
||||
constant bool fused_activation [[function_constant(101)]];
|
||||
constant bool fused_bias [[function_constant(50001)]]; // 102
|
||||
constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false;
|
||||
constant bool use_activation_function = fused_activation && !fused_bias;
|
||||
constant bool use_activation = use_bias || use_activation_function;
|
||||
constant bool batched_activation_function = batched && use_activation_function;
|
||||
|
||||
constant ushort M_simd [[function_constant(200)]];
|
||||
constant ushort N_simd [[function_constant(201)]];
|
||||
constant ushort K_simd [[function_constant(202)]];
|
||||
|
||||
// Elide work on the edge when matrix dimension < SRAM block dimension.
|
||||
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
|
||||
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
|
||||
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
|
||||
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
|
||||
|
||||
constant ushort M_splits [[function_constant(210)]];
|
||||
constant ushort N_splits [[function_constant(211)]];
|
||||
|
||||
constant ushort M_group = M_simd * M_splits;
|
||||
constant ushort N_group = N_simd * N_splits;
|
||||
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
|
||||
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
|
||||
|
||||
// There is no padding for M reads/writes.
|
||||
// There is no padding for N reads/writes.
|
||||
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
|
||||
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
|
||||
|
||||
constant ushort A_sram_length = (M_simd / 8) * 1;
|
||||
constant ushort B_sram_length = 1 * (N_simd / 8);
|
||||
constant ushort A_block_length = M_group * K_simd;
|
||||
|
||||
// Threadgroup block must fit entire C accumulator and partial sums.
|
||||
constant ushort A_sram_offset = 0;
|
||||
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
|
||||
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
|
||||
constant ushort A_block_offset = 0;
|
||||
constant ushort B_block_offset = A_block_offset + A_block_length;
|
||||
|
||||
// MARK: - Utilities
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// A_sram[M_simd][8]
|
||||
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// A_sram[8][N_simd]
|
||||
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
|
||||
// C_sram[M_simd][N_simd]
|
||||
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
|
||||
ushort2 A_tile_src, uint2 A_offset,
|
||||
threadgroup T *B_block, device T *B,
|
||||
ushort2 B_tile_src, uint2 B_offset, uint k)
|
||||
{
|
||||
A_tile_src.x = min(uint(K_simd), K - k);
|
||||
B_tile_src.y = min(uint(K_simd), K - k);
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<T>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
// Rounded-up ceiling for the threadgroup block.
|
||||
const uint K_edge_floor = K - K_simd_unpadded;
|
||||
const uint K_edge_ceil = K_edge_floor + K_simd_padded;
|
||||
ushort K_padded;
|
||||
if (K_edge_floor == K_simd) {
|
||||
K_padded = K_simd;
|
||||
} else {
|
||||
K_padded = min(uint(K_simd), K_edge_ceil - k);
|
||||
}
|
||||
ushort2 A_tile_dst(K_padded, A_tile_src.y);
|
||||
ushort2 B_tile_dst(B_tile_src.x, K_padded);
|
||||
|
||||
simdgroup_event events[2];
|
||||
events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
|
||||
events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
|
||||
simdgroup_event::wait(2, events);
|
||||
}
|
||||
|
||||
// One iteration of the MACC loop, effectively k=8 iterations.
|
||||
template <typename T>
|
||||
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
|
||||
const threadgroup T *A_block,
|
||||
const threadgroup T *B_block,
|
||||
bool accumulate = true)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto A = A_sram(sram, ushort2(0, m));
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
C->multiply(*A, *B, accumulate);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
|
||||
threadgroup T *C_block, bool is_k_summation)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
if (is_k_summation) {
|
||||
C_sram(sram, origin)->store(C_block, N_simd, origin);
|
||||
} else {
|
||||
C_sram(sram, origin)->store(C_block, N_group, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
|
||||
threadgroup T *C_block, bool is_k_summation)
|
||||
{
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
if (is_k_summation) {
|
||||
B->load(C_block, N_simd, origin);
|
||||
} else {
|
||||
B->load(C_block, N_group, origin);
|
||||
}
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto B = B_sram(sram, ushort2(n, 0));
|
||||
auto C = C_sram(sram, origin);
|
||||
if (is_k_summation) {
|
||||
C->thread_elements()[0] += B->thread_elements()[0];
|
||||
} else {
|
||||
float2 C_old = float2(B->thread_elements()[0]);
|
||||
float2 C_new = float2(C->thread_elements()[0]);
|
||||
C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
|
||||
uint2 C_offset, bool is_store)
|
||||
{
|
||||
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
|
||||
min(uint(M_group), M - C_offset.y));
|
||||
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);
|
||||
|
||||
if (is_store) {
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
|
||||
simdgroup_event::wait(1, &event);
|
||||
} else {
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
|
||||
simdgroup_event::wait(1, &event);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
|
||||
device T *C, bool m_is_edge, bool n_is_edge)
|
||||
{
|
||||
const ushort m_start = (m_is_edge) ? M_modulo : 0;
|
||||
const ushort n_start = (n_is_edge) ? N_modulo : 0;
|
||||
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
|
||||
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = m_start; m < m_end; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = n_start; n < n_end; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
C_sram(sram, origin)->store(C, N, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct activation_functor {
|
||||
using function = void(threadgroup T *C,
|
||||
device void *D,
|
||||
uint grid_index_in_batch,
|
||||
uint2 matrix_origin,
|
||||
ushort2 tile_dimensions,
|
||||
ushort lane_id);
|
||||
|
||||
typedef visible_function_table<function> function_table;
|
||||
};
|
||||
|
||||
// MARK: - Kernels
|
||||
|
||||
template <typename T>
|
||||
void _gemm_impl(device T *A [[buffer(0)]],
|
||||
device T *B [[buffer(1)]],
|
||||
device T *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup T *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<T>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
if (batched) {
|
||||
// TODO: Re-compute every inner loop iteration for FP64 accumulate.
|
||||
ulong3 offsets = matrix_offsets[0].xyz * gid.z;
|
||||
A = (device T*)((device uchar*)A + offsets[0]);
|
||||
B = (device T*)((device uchar*)B + offsets[1]);
|
||||
C = (device T*)((device uchar*)C + offsets[2]);
|
||||
}
|
||||
|
||||
simdgroup_matrix_storage<T> sram[1024];
|
||||
auto A_block = threadgroup_block + A_block_offset;
|
||||
auto B_block = threadgroup_block + B_block_offset;
|
||||
ushort2 sid(sidx % N_splits, sidx / N_splits);
|
||||
ushort2 offset_in_simd = simdgroup_matrix_storage<T>::offset(lane_id);
|
||||
|
||||
uint2 A_offset(0, gid.y * M_group);
|
||||
uint2 B_offset(gid.x * N_group, 0);
|
||||
{
|
||||
uint C_base_offset_x = B_offset.x + sid.x * N_simd;
|
||||
uint C_base_offset_y = A_offset.y + sid.y * M_simd;
|
||||
if (C_base_offset_x >= N || C_base_offset_y >= M) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
|
||||
sid.y * M_simd + offset_in_simd.y);
|
||||
|
||||
if (use_bias) {
|
||||
if (sidx == 0) {
|
||||
auto bias = (device T*)D;
|
||||
if (batched) {
|
||||
ulong offset = matrix_offsets[gid.z].w;
|
||||
bias = (device T*)((device uchar*)bias + offset);
|
||||
}
|
||||
|
||||
ushort bias_elements;
|
||||
if (is_function_constant_defined(D_trans) && D_trans) {
|
||||
bias += A_offset.y;
|
||||
bias_elements = min(uint(M_group), M - A_offset.y);
|
||||
} else {
|
||||
bias += B_offset.x;
|
||||
bias_elements = min(uint(N_group), N - B_offset.x);
|
||||
}
|
||||
|
||||
simdgroup_event event;
|
||||
event.async_copy(threadgroup_block, bias, bias_elements);
|
||||
simdgroup_event::wait(1, &event);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (is_function_constant_defined(D_trans) && D_trans) {
|
||||
auto bias = threadgroup_block + offset_in_group.y;
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto D = bias[m];
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
*(C->thread_elements()) = D;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto bias = threadgroup_block + offset_in_group.x;
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
auto D = *(threadgroup vec<T, 2>*)(bias + n);
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
auto C = C_sram(sram, ushort2(n, m));
|
||||
*(C->thread_elements()) = D;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
ushort2 A_tile_src;
|
||||
ushort2 B_tile_src;
|
||||
if (sidx == 0) {
|
||||
A_tile_src.y = min(uint(M_group), M - A_offset.y);
|
||||
B_tile_src.x = min(uint(N_group), N - B_offset.x);
|
||||
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0);
|
||||
}
|
||||
|
||||
if (K > K_simd && !use_bias) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_padded; n += 8) {
|
||||
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<T>(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
|
||||
ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
|
||||
ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
|
||||
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans);
|
||||
auto B_block_src = simdgroup_matrix_storage<T>::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = 0; k < K_simd_padded; k += 8) {
|
||||
bool accumulate = use_bias || !(K <= K_simd && k == 0);
|
||||
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
|
||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||
}
|
||||
|
||||
if (K_floor + K_simd < K) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
|
||||
multiply_accumulate(sram, A_block_src, B_block_src);
|
||||
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
|
||||
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
uint K_next = K_floor + K_simd;
|
||||
A_offset.x = K_next;
|
||||
B_offset.y = K_next;
|
||||
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (alpha != 1) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int m = 0; m < M_padded; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int n = 0; n < N_padded; n += 8) {
|
||||
C_sram(sram, ushort2(n, m))->thread_elements()[0] *= T(alpha);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint2 C_offset(B_offset.x, A_offset.y);
|
||||
ushort2 C_block_offset = offset_in_group.xy;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (beta != 0) {
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, false);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_accumulate(sram, C_block, false);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (use_activation_function) {
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_store(sram, C_block, false);
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
uint grid_index_in_batch = (batched ? gid.z : 0);
|
||||
uint2 matrix_origin = C_offset + uint2(C_block_offset);
|
||||
matrix_origin &= ~7;
|
||||
ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x),
|
||||
min(uint(M_group), M - matrix_origin.y));
|
||||
uint function_index = 0;
|
||||
if (batched_activation_function) {
|
||||
function_index = activation_function_offsets[gid.z];
|
||||
}
|
||||
table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, true);
|
||||
}
|
||||
return;
|
||||
} else if ((M % 8 != 0) || (N % 8 != 0)) {
|
||||
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
|
||||
partial_store(sram, C_block, false);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sidx == 0) {
|
||||
async_access_accumulator(threadgroup_block, C, C_offset, true);
|
||||
}
|
||||
} else {
|
||||
uint2 matrix_origin = C_offset + uint2(C_block_offset);
|
||||
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, matrix_origin);
|
||||
store_accumulator(sram, C_src, false, false);
|
||||
|
||||
const uint M_edge_floor = M - M % M_simd;
|
||||
const uint N_edge_floor = N - N % N_simd;
|
||||
if (matrix_origin.y < M_edge_floor) {
|
||||
store_accumulator(sram, C_src, true, false);
|
||||
}
|
||||
if (matrix_origin.x < N_edge_floor) {
|
||||
store_accumulator(sram, C_src, false, true);
|
||||
if (matrix_origin.y < M_edge_floor) {
|
||||
store_accumulator(sram, C_src, true, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void hgemm(device half *A [[buffer(0)]],
|
||||
device half *B [[buffer(1)]],
|
||||
device half *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup half *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<half>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<half>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
|
||||
kernel void sgemm(device float *A [[buffer(0)]],
|
||||
device float *B [[buffer(1)]],
|
||||
device float *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup float *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<float>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<float>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
kernel void bgemm(
|
||||
device bfloat *A [[buffer(0)]],
|
||||
device bfloat *B [[buffer(1)]],
|
||||
device bfloat *C [[buffer(2)]],
|
||||
device void *D [[buffer(3), function_constant(use_activation)]],
|
||||
|
||||
threadgroup bfloat *threadgroup_block [[threadgroup(0)]],
|
||||
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
typename activation_functor<bfloat>::function_table table [[buffer(11), function_constant(use_activation_function)]],
|
||||
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]])
|
||||
{
|
||||
_gemm_impl<bfloat>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
|
||||
}
|
||||
#endif
|
@ -11,34 +11,35 @@ pub use utils::BufferOffset;
|
||||
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
// const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const GEMM: &str = include_str!("gemm.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Affine,
|
||||
Indexing,
|
||||
Unary,
|
||||
Binary,
|
||||
Ternary,
|
||||
Cast,
|
||||
Reduce,
|
||||
Mfa,
|
||||
Conv,
|
||||
Random,
|
||||
Gemm,
|
||||
Indexing,
|
||||
Mfa,
|
||||
Quantized,
|
||||
Random,
|
||||
Reduce,
|
||||
Sort,
|
||||
Ternary,
|
||||
Unary,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
@ -192,17 +193,18 @@ impl Kernels {
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
Source::Binary => BINARY,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Cast => CAST,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Gemm => MLX_GEMM,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Random => RANDOM,
|
||||
Source::Reduce => REDUCE,
|
||||
Source::Sort => SORT,
|
||||
Source::Mfa => GEMM,
|
||||
Source::Ternary => TERNARY,
|
||||
Source::Unary => UNARY,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,14 +219,22 @@ impl Kernels {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let compile_options = CompileOptions::new();
|
||||
compile_options.set_fast_math_enabled(true);
|
||||
|
||||
let source_content = self.get_library_source(source);
|
||||
let lib = device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
||||
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
};
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
}
|
||||
@ -2171,5 +2181,181 @@ pub fn call_arg_sort(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum GemmDType {
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_mlx_gemm(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
lhs_buffer: &Buffer,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
#[derive(Debug)]
|
||||
#[repr(C)]
|
||||
struct GemmParams {
|
||||
m: i32,
|
||||
n: i32,
|
||||
k: i32,
|
||||
lda: i32,
|
||||
ldb: i32,
|
||||
ldd: i32,
|
||||
tiles_n: i32,
|
||||
tiles_m: i32,
|
||||
batch_stride_a: isize,
|
||||
batch_stride_b: isize,
|
||||
batch_stride_d: isize,
|
||||
swizzle_log: i32,
|
||||
gemm_k_iterations_aligned: i32,
|
||||
batch_ndim: i32,
|
||||
}
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// lhs has shape b, m, k
|
||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||
// there is a single element.
|
||||
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||
(k as i32, false)
|
||||
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||
(m as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
// rhs has shape b, k, n
|
||||
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||
(n as i32, false)
|
||||
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||
(k as i32, true)
|
||||
} else {
|
||||
return Err(MetalKernelError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?;
|
||||
};
|
||||
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
|
||||
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(10, Value::Bool(/* has_batch */ b > 1)),
|
||||
(100, Value::Bool(/* use_out_source */ false)),
|
||||
(110, Value::Bool(/* do_axpby */ false)),
|
||||
(200, Value::Bool(/* align_m */ m % bm == 0)),
|
||||
(201, Value::Bool(/* align_n */ n % bn == 0)),
|
||||
(202, Value::Bool(/* align_k */ k % bk == 0)),
|
||||
(300, Value::Bool(/* do_gather */ false)),
|
||||
]));
|
||||
|
||||
let swizzle_log = 0;
|
||||
let tile = 1 << swizzle_log;
|
||||
let tn = n.div_ceil(bn);
|
||||
let tm = m.div_ceil(bm);
|
||||
let tn = tn * tile;
|
||||
let tm = tm.div_ceil(tile);
|
||||
|
||||
let batch_stride_a = if lhs_stride.len() > 2 {
|
||||
lhs_stride[lhs_stride.len() - 3]
|
||||
} else {
|
||||
m * k
|
||||
};
|
||||
let batch_stride_b = if rhs_stride.len() > 2 {
|
||||
rhs_stride[rhs_stride.len() - 3]
|
||||
} else {
|
||||
n * k
|
||||
};
|
||||
|
||||
let gemm_params = GemmParams {
|
||||
m: m as i32,
|
||||
n: n as i32,
|
||||
k: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldd: n as i32,
|
||||
tiles_n: tn as i32,
|
||||
tiles_m: tm as i32,
|
||||
swizzle_log,
|
||||
batch_stride_a: batch_stride_a as isize,
|
||||
batch_stride_b: batch_stride_b as isize,
|
||||
batch_stride_d: (m * n) as isize,
|
||||
batch_ndim: 1i32,
|
||||
gemm_k_iterations_aligned: (k / bk) as i32,
|
||||
};
|
||||
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
|
||||
|
||||
// TODO(laurent): generate the name
|
||||
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
|
||||
let name = match (dtype, a_trans, b_trans) {
|
||||
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
|
||||
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
|
||||
};
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(3, Some(output), 0);
|
||||
encoder.set_bytes(
|
||||
4,
|
||||
std::mem::size_of::<GemmParams>() as u64,
|
||||
&gemm_params as *const GemmParams as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
6, // batch_shape
|
||||
std::mem::size_of::<i32>() as u64,
|
||||
&(b as i32) as *const i32 as *const c_void,
|
||||
);
|
||||
encoder.set_bytes(
|
||||
7,
|
||||
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
|
||||
batch_strides.as_ptr() as *const c_void,
|
||||
);
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: tn as u64,
|
||||
height: tm as u64,
|
||||
depth: /* batch_size_out */ b as u64,
|
||||
};
|
||||
let group_size = MTLSize {
|
||||
width: 32,
|
||||
height: wn,
|
||||
depth: wm,
|
||||
};
|
||||
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
Binary file not shown.
1440
candle-metal-kernels/src/mlx_gemm.metal
Normal file
1440
candle-metal-kernels/src/mlx_gemm.metal
Normal file
File diff suppressed because it is too large
Load Diff
@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
|
||||
#[test]
|
||||
fn cast_f32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -360,7 +360,7 @@ fn cast_f32() {
|
||||
|
||||
#[test]
|
||||
fn cast_f16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -391,7 +391,7 @@ fn cast_f16() {
|
||||
|
||||
#[test]
|
||||
fn cast_bf16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -422,7 +422,7 @@ fn cast_bf16() {
|
||||
|
||||
#[test]
|
||||
fn cast_u32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -453,7 +453,7 @@ fn cast_u32() {
|
||||
|
||||
#[test]
|
||||
fn cast_u8() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -484,7 +484,7 @@ fn cast_u8() {
|
||||
|
||||
#[test]
|
||||
fn cast_i64() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f64 = [1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
@ -911,7 +911,7 @@ fn softmax() {
|
||||
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -922,7 +922,7 @@ fn softmax() {
|
||||
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
|
||||
);
|
||||
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() {
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_gemm<T: Clone>(
|
||||
name: &'static str,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: Vec<usize>,
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: Vec<usize>,
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>(
|
||||
&kernels,
|
||||
name,
|
||||
(b, m, n, k),
|
||||
&lhs_stride,
|
||||
lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
&rhs_stride,
|
||||
rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
@ -1105,10 +1106,10 @@ fn gemm() {
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
@ -1125,10 +1126,10 @@ fn gemm() {
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
@ -1150,10 +1151,10 @@ fn gemm() {
|
||||
"sgemm",
|
||||
(1, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
12 * 4,
|
||||
);
|
||||
assert_eq!(
|
||||
@ -1162,25 +1163,27 @@ fn gemm() {
|
||||
);
|
||||
|
||||
// bgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"bgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
if false {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs_stride = vec![m * k, k, 1];
|
||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let results = run_gemm(
|
||||
"bgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
// hgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
@ -1192,10 +1195,10 @@ fn gemm() {
|
||||
"hgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
lhs_stride,
|
||||
&lhs_stride,
|
||||
0,
|
||||
&rhs,
|
||||
rhs_stride,
|
||||
&rhs_stride,
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
@ -1204,6 +1207,204 @@ fn gemm() {
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_mlx_gemm<T: Clone>(
|
||||
dtype: GemmDType,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs: &[T],
|
||||
lhs_stride: &[usize],
|
||||
lhs_offset: usize,
|
||||
rhs: &[T],
|
||||
rhs_stride: &[usize],
|
||||
rhs_offset: usize,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
|
||||
let lhs = device.new_buffer_with_data(
|
||||
lhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(lhs) as u64,
|
||||
options,
|
||||
);
|
||||
let rhs = device.new_buffer_with_data(
|
||||
rhs.as_ptr() as *const core::ffi::c_void,
|
||||
std::mem::size_of_val(rhs) as u64,
|
||||
options,
|
||||
);
|
||||
let length = b * m * n;
|
||||
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_mlx_gemm(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
lhs_stride,
|
||||
lhs_offset,
|
||||
&lhs,
|
||||
rhs_stride,
|
||||
rhs_offset,
|
||||
&rhs,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, length)
|
||||
}
|
||||
|
||||
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
|
||||
use rand::SeedableRng;
|
||||
use rand_distr::Distribution;
|
||||
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
|
||||
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
|
||||
|
||||
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
|
||||
let v1: Vec<f32> = run_mlx_gemm(
|
||||
dtype,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
let v2: Vec<f32> = run_gemm(
|
||||
"sgemm",
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[k * n, n, 1],
|
||||
0,
|
||||
);
|
||||
for (a, b) in v1.iter().zip(v2.iter()) {
|
||||
let diff = (a - b).abs();
|
||||
assert_eq!((diff * 1e4).round(), 0.)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mlx_vs_mfa() {
|
||||
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
|
||||
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
|
||||
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mlx_gemm() {
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_mlx_gemm(
|
||||
GemmDType::F32,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
let results = run_mlx_gemm(
|
||||
GemmDType::F32,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![
|
||||
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
|
||||
518.0, 548.0, 578.0
|
||||
]
|
||||
);
|
||||
|
||||
// OFFSET
|
||||
let (b, m, n, k) = (2, 2, 4, 3);
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
|
||||
let results = run_mlx_gemm(
|
||||
GemmDType::F32,
|
||||
(1, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
12 * 4,
|
||||
);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
|
||||
);
|
||||
|
||||
// bgemm sanity test
|
||||
{
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
|
||||
let results = run_mlx_gemm(
|
||||
GemmDType::BF16,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_bf16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// hgemm sanity test
|
||||
let (b, m, n, k) = (1, 2, 4, 3);
|
||||
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
|
||||
let results = run_mlx_gemm(
|
||||
GemmDType::F16,
|
||||
(b, m, n, k),
|
||||
&lhs,
|
||||
&[m * k, k, 1],
|
||||
0,
|
||||
&rhs,
|
||||
&[n * k, n, 1],
|
||||
0,
|
||||
);
|
||||
assert_eq!(
|
||||
approx_f16(results, 4),
|
||||
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
@ -1278,7 +1479,7 @@ fn random() {
|
||||
variance.sqrt()
|
||||
}
|
||||
|
||||
let shape = vec![1024, 10];
|
||||
let shape = [1024, 10];
|
||||
|
||||
let length = shape.iter().product::<usize>();
|
||||
let seed = 299792458;
|
||||
@ -1634,7 +1835,7 @@ fn max_pool2d_f16() {
|
||||
&strides,
|
||||
"max_pool2d_f16",
|
||||
);
|
||||
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -1654,7 +1855,7 @@ fn max_pool2d_f16() {
|
||||
&strides,
|
||||
"max_pool2d_f16",
|
||||
);
|
||||
let expected = vec![5.0, 7.0, 13.0, 15.0]
|
||||
let expected = [5.0, 7.0, 13.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -1677,7 +1878,7 @@ fn max_pool2d_bf16() {
|
||||
&strides,
|
||||
"max_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -1697,7 +1898,7 @@ fn max_pool2d_bf16() {
|
||||
&strides,
|
||||
"max_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![5.0, 7.0, 13.0, 15.0]
|
||||
let expected = [5.0, 7.0, 13.0, 15.0]
|
||||
.iter()
|
||||
.map(|v| half::bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -1816,7 +2017,7 @@ fn avg_pool2d_f16() {
|
||||
&strides,
|
||||
"avg_pool2d_f16",
|
||||
);
|
||||
let expected = vec![
|
||||
let expected = [
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
@ -1841,7 +2042,7 @@ fn avg_pool2d_bf16() {
|
||||
&strides,
|
||||
"avg_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![
|
||||
let expected = [
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
@ -1979,14 +2180,14 @@ fn conv_transpose1d_f32() {
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_f16() {
|
||||
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
@ -2007,7 +2208,7 @@ fn conv_transpose1d_f16() {
|
||||
"conv_transpose1d_f16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
let expected = [1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
@ -2016,14 +2217,14 @@ fn conv_transpose1d_f16() {
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_bf16() {
|
||||
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
@ -2044,7 +2245,7 @@ fn conv_transpose1d_bf16() {
|
||||
"conv_transpose1d_bf16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
let expected = [1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
|
@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
||||
T x_cube = x_sq * x;
|
||||
T alpha = x + static_cast<T>(0.044715) * x_cube;
|
||||
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
|
||||
}
|
||||
template <typename T> METAL_FUNC T relu(T in){
|
||||
if (in < 0) {
|
||||
@ -154,7 +154,6 @@ UNARY_OP(floor)
|
||||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
UNARY_OP(sign)
|
||||
@ -164,6 +163,11 @@ UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
|
||||
|
||||
// tanh may create NaN on large values, e.g. 45 rather than outputing 1.
|
||||
// This has been an issue for the encodec example.
|
||||
UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
|
||||
UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
COPY2D(copy2d_i64, int64_t)
|
||||
@ -185,7 +189,6 @@ BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
BFLOAT_UNARY_OP(sign)
|
||||
@ -193,5 +196,7 @@ BFLOAT_UNARY_OP(sigmoid)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
|
||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||
|
||||
COPY2D(copy2d_bf16, bfloat)
|
||||
#endif
|
||||
|
@ -165,20 +165,25 @@ pub trait EncoderProvider {
|
||||
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
|
||||
fn encoder(&self) -> Self::Encoder<'_>;
|
||||
}
|
||||
|
||||
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
|
||||
pub struct WrappedEncoder<'a> {
|
||||
inner: &'a ComputeCommandEncoderRef,
|
||||
end_encoding_on_drop: bool,
|
||||
}
|
||||
|
||||
impl<'a> Drop for WrappedEncoder<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.0.end_encoding()
|
||||
if self.end_encoding_on_drop {
|
||||
self.inner.end_encoding()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
|
||||
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
||||
&self.0
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,8 +191,11 @@ impl EncoderProvider for &metal::CommandBuffer {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
WrappedEncoder {
|
||||
inner: self.new_compute_command_encoder(),
|
||||
end_encoding_on_drop: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -195,7 +203,22 @@ impl EncoderProvider for &metal::CommandBufferRef {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
|
||||
WrappedEncoder(self.new_compute_command_encoder())
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
WrappedEncoder {
|
||||
inner: self.new_compute_command_encoder(),
|
||||
end_encoding_on_drop: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderProvider for &ComputeCommandEncoderRef {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
WrappedEncoder {
|
||||
inner: self,
|
||||
end_encoding_on_drop: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -55,6 +55,10 @@ pub struct LSTMState {
|
||||
}
|
||||
|
||||
impl LSTMState {
|
||||
pub fn new(h: Tensor, c: Tensor) -> Self {
|
||||
LSTMState { h, c }
|
||||
}
|
||||
|
||||
/// The hidden state vector, which is also the output of the LSTM.
|
||||
pub fn h(&self) -> &Tensor {
|
||||
&self.h
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.6.0" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.7.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.7.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -66,6 +66,18 @@ impl Attr for GraphProto {
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Vec<String> {
|
||||
const TYPE: AttributeType = AttributeType::Strings;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
let mut ret = vec![];
|
||||
for bytes in attr.strings.iter() {
|
||||
let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
|
||||
ret.push(s);
|
||||
}
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
@ -340,8 +352,15 @@ fn simple_eval_(
|
||||
"Pow" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
// HACK: current implementation of broadcast_pow cannot handle negative base,
|
||||
// so we use powf where we can, which *does* correctly handle negative base.
|
||||
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
|
||||
let output = input0.powf(exp as f64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
} else {
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
}
|
||||
"Exp" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
@ -610,6 +629,18 @@ fn simple_eval_(
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = xs.normalize_axis(axis)?;
|
||||
|
||||
// index_select does not support negative indices, so normalize them
|
||||
// to positive indices.
|
||||
let indices = &{
|
||||
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
|
||||
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
|
||||
.to_dtype(indices.dtype())?;
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(&indices)?
|
||||
};
|
||||
|
||||
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||
// tensor directly, but candle does not support tensor indexing at the moment, so
|
||||
// some workarounds must be done.
|
||||
@ -1310,6 +1341,233 @@ fn simple_eval_(
|
||||
.broadcast_add(&c.broadcast_mul(&beta)?)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"LSTM" => {
|
||||
let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
|
||||
if direction != "forward" {
|
||||
bail!("LSTM currently only supports direction == \"forward\"");
|
||||
}
|
||||
let num_directions = if direction == "bidirectional" { 2 } else { 1 };
|
||||
let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
|
||||
let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
|
||||
if input_forget != 0 {
|
||||
bail!("LSTM currently only supports input_forget == 0");
|
||||
}
|
||||
let activations_default = vec![
|
||||
"Sigmoid".to_string(),
|
||||
"Tanh".to_string(),
|
||||
"Tanh".to_string(),
|
||||
];
|
||||
let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
|
||||
.unwrap_or(activations_default.clone());
|
||||
if activations != activations_default {
|
||||
bail!("LSTM currently only supports default activations ({activations_default:?})");
|
||||
}
|
||||
// activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
|
||||
if get_attr_opt::<f32>(node, "clip")?.is_some() {
|
||||
bail!("LSTM does not currently support clip attribute");
|
||||
}
|
||||
|
||||
// The shape format of inputs X, initial_h and outputs Y, Y_h.
|
||||
// If 0, the following shapes are expected:
|
||||
// X.shape = [seq_length, batch_size, input_size],
|
||||
// Y.shape = [seq_length, num_directions, batch_size, hidden_size],
|
||||
// initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
|
||||
// If 1, the following shapes are expected:
|
||||
// X.shape = [batch_size, seq_length, input_size],
|
||||
// Y.shape = [batch_size, seq_length, num_directions, hidden_size],
|
||||
// initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
|
||||
let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
|
||||
if layout != 0 {
|
||||
bail!("LSTM currently only supports layout == 0");
|
||||
}
|
||||
|
||||
// The input sequences packed (and potentially padded) into one 3-D tensor
|
||||
// with the shape of `[seq_length, batch_size, input_size]`.
|
||||
let x = get(&node.input[0])?;
|
||||
// XXX: depends on layout
|
||||
let (seq_length, batch_size, input_size) = x.dims3()?;
|
||||
// The weight tensor for the gates.
|
||||
// Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
|
||||
// The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
|
||||
let w = get(&node.input[1])?;
|
||||
// The recurrence weight tensor.
|
||||
// Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
|
||||
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
|
||||
let r = get(&node.input[2])?;
|
||||
|
||||
let get_opt = |i: usize| {
|
||||
node.input
|
||||
.get(i)
|
||||
.filter(|s: &&String| !s.is_empty())
|
||||
.map(|s| get(s))
|
||||
};
|
||||
|
||||
// The bias tensor for input gate.
|
||||
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
|
||||
// This tensor has shape `[num_directions, 8*hidden_size]`.
|
||||
// Optional: If not specified - assumed to be 0.
|
||||
let b_default: Tensor;
|
||||
let b = match get_opt(3) {
|
||||
Some(n) => n?,
|
||||
None => {
|
||||
b_default = Tensor::zeros(
|
||||
(num_directions, 8 * hidden_size as usize),
|
||||
DType::F32,
|
||||
x.device(),
|
||||
)?;
|
||||
&b_default
|
||||
}
|
||||
};
|
||||
|
||||
// Optional tensor specifying lengths of the sequences in a batch.
|
||||
// If not specified - assumed all sequences in the batch to have length `seq_length`.
|
||||
// It has shape `[batch_size]`.
|
||||
let seq_lens_default: Tensor;
|
||||
let seq_lens = match get_opt(4) {
|
||||
Some(n) => n?,
|
||||
None => {
|
||||
seq_lens_default =
|
||||
Tensor::full(seq_length as i64, (batch_size,), x.device())?;
|
||||
&seq_lens_default
|
||||
}
|
||||
};
|
||||
let seq_lens_is_default =
|
||||
(seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
|
||||
if !seq_lens_is_default {
|
||||
bail!("LSTM currently only supports default value of seq_lens");
|
||||
}
|
||||
|
||||
// Optional initial value of the hidden. If not specified - assumed to be 0.
|
||||
// It has shape `[num_directions, batch_size, hidden_size]`.
|
||||
let initial_h_default: Tensor;
|
||||
let initial_h = match get_opt(5) {
|
||||
Some(n) => n?,
|
||||
_ => {
|
||||
initial_h_default = Tensor::zeros(
|
||||
(num_directions, batch_size, hidden_size as usize),
|
||||
DType::F32,
|
||||
x.device(),
|
||||
)?;
|
||||
&initial_h_default
|
||||
}
|
||||
};
|
||||
|
||||
// Optional initial value of the cell.
|
||||
// If not specified - assumed to be 0.
|
||||
// It has shape `[num_directions, batch_size, hidden_size]`.
|
||||
let initial_c_default: Tensor;
|
||||
let initial_c = match node.input.get(6) {
|
||||
Some(n) if !n.is_empty() => get(n)?,
|
||||
_ => {
|
||||
initial_c_default = Tensor::zeros(
|
||||
(num_directions, batch_size, hidden_size as usize),
|
||||
DType::F32,
|
||||
x.device(),
|
||||
)?;
|
||||
&initial_c_default
|
||||
}
|
||||
};
|
||||
|
||||
// The weight tensor for peepholes.
|
||||
// Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
|
||||
// It has shape `[num_directions, 3*hidde_size]`. Optional: If not specified - assumed to be 0.
|
||||
let p_default = Tensor::zeros(
|
||||
(num_directions, 3 * hidden_size as usize),
|
||||
DType::F32,
|
||||
x.device(),
|
||||
)?;
|
||||
let p = get_opt(7).unwrap_or(Ok(&p_default))?;
|
||||
let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
|
||||
if !p_is_zeros {
|
||||
bail!(
|
||||
"LSTM currently only supports default value of p (a Tensor of all zeroes)"
|
||||
);
|
||||
}
|
||||
|
||||
// these all have [num_directions, ...] shapes
|
||||
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
|
||||
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
|
||||
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
|
||||
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
|
||||
let wb = b.index_select(&idx_wb, 0)?;
|
||||
let rb = b.index_select(&idx_rb, 0)?;
|
||||
let c = initial_c.get(0)?;
|
||||
let h = initial_h.get(0)?;
|
||||
|
||||
// w, r, wb, rb are all iofc but lstm expects ifco
|
||||
// so we need to move some stuff around
|
||||
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
|
||||
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
|
||||
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
|
||||
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
|
||||
let w = w.index_select(&idx_ifco, 0)?;
|
||||
let r = r.index_select(&idx_ifco, 0)?;
|
||||
let wb = wb.index_select(&idx_ifco, 0)?;
|
||||
let rb = rb.index_select(&idx_ifco, 0)?;
|
||||
let vmap = candle_nn::VarMap::new();
|
||||
vmap.data().lock().unwrap().extend([
|
||||
("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
|
||||
("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
|
||||
("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
|
||||
("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
|
||||
]);
|
||||
use candle_nn::rnn::RNN as _;
|
||||
let lstm = candle_nn::rnn::lstm(
|
||||
input_size,
|
||||
hidden_size as usize,
|
||||
candle_nn::rnn::LSTMConfig::default(),
|
||||
candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
|
||||
)?;
|
||||
|
||||
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
|
||||
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
for t in 0..seq_length {
|
||||
let x = x.get(t)?;
|
||||
lstm_state = lstm.step(&x, &lstm_state)?;
|
||||
if let Some(h_acc) = &mut h_acc {
|
||||
h_acc.push(lstm_state.clone());
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
|
||||
if let Some(name) = node.output.get(0) {
|
||||
let h_acc = h_acc.as_ref().unwrap();
|
||||
let h_acc = lstm.states_to_tensor(h_acc)?;
|
||||
let h_acc = h_acc.reshape((
|
||||
seq_length,
|
||||
num_directions,
|
||||
batch_size,
|
||||
hidden_size as usize,
|
||||
))?;
|
||||
values.insert(name.clone(), h_acc);
|
||||
}
|
||||
if let Some(name) = node.output.get(1) {
|
||||
values.insert(
|
||||
name.clone(),
|
||||
lstm_state.h().reshape((
|
||||
num_directions,
|
||||
batch_size,
|
||||
hidden_size as usize,
|
||||
))?,
|
||||
);
|
||||
}
|
||||
if let Some(name) = node.output.get(2) {
|
||||
values.insert(
|
||||
name.clone(),
|
||||
lstm_state.c().reshape((
|
||||
num_directions,
|
||||
batch_size,
|
||||
hidden_size as usize,
|
||||
))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -6,11 +6,13 @@ extern crate accelerate_src;
|
||||
|
||||
use candle::test_utils::to_vec2_round;
|
||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::eval::Value;
|
||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
|
||||
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use candle_onnx::simple_eval;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const INPUT_X: &str = "x";
|
||||
@ -3514,3 +3516,467 @@ fn test_slice() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lstm() -> Result<()> {
|
||||
// values generated from pytorch, so at least it's close enough to what pytorch does
|
||||
/*
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)
|
||||
|
||||
import torch
|
||||
|
||||
rand_gen = torch.Generator()
|
||||
rand_gen.manual_seed(1)
|
||||
input_size = 3
|
||||
hidden_size = 5
|
||||
batch_size = 1
|
||||
sequence_length = 4
|
||||
number_directions = 1
|
||||
rnn = torch.nn.LSTM(input_size,hidden_size)
|
||||
weight_ih_l0 = torch.randn(rnn.weight_ih_l0.shape, generator=rand_gen)
|
||||
weight_hh_l0 = torch.randn(rnn.weight_hh_l0.shape, generator=rand_gen)
|
||||
bias_ih_l0 = torch.randn(rnn.bias_ih_l0.shape, generator=rand_gen)
|
||||
bias_hh_l0 = torch.randn(rnn.bias_hh_l0.shape, generator=rand_gen)
|
||||
rnn.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
|
||||
rnn.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
|
||||
rnn.bias_ih_l0 = torch.nn.Parameter(bias_ih_l0)
|
||||
rnn.bias_hh_l0 = torch.nn.Parameter(bias_hh_l0)
|
||||
input = torch.randn(sequence_length, batch_size, input_size, generator=rand_gen)
|
||||
h0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
|
||||
c0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
|
||||
output, (hn, cn) = rnn(input, (h0, c0))
|
||||
|
||||
def fmt_tensor(t):
|
||||
return "Tensor::from_vec::<_, f32>(vec!"+ str(t.flatten().tolist()) + ", (" + "".join([str(n)+"," for n in t.shape])+"), &Device::Cpu)?"
|
||||
|
||||
print("let input_size = ", input_size, ";")
|
||||
print("let hidden_size = ", hidden_size, ";")
|
||||
print("let batch_size = ", batch_size, ";")
|
||||
print("let sequence_length = ", sequence_length, ";")
|
||||
print("let number_directions = ", number_directions, ";")
|
||||
print("let weight_ih_l0 = ", fmt_tensor(rnn.weight_ih_l0), ";")
|
||||
print("let weight_hh_l0 = ", fmt_tensor(rnn.weight_hh_l0), ";")
|
||||
print("let bias_ih_l0 = ", fmt_tensor(rnn.bias_ih_l0), ";")
|
||||
print("let bias_hh_l0 = ", fmt_tensor(rnn.bias_hh_l0), ";")
|
||||
print("let input = ", fmt_tensor(input), ";")
|
||||
print("let h0 = ", fmt_tensor(h0), ";")
|
||||
print("let c0 = ", fmt_tensor(c0), ";")
|
||||
print("let output = ", fmt_tensor(output), ";")
|
||||
print("let hn = ", fmt_tensor(hn), ";")
|
||||
print("let cn = ", fmt_tensor(cn), ";")
|
||||
*/
|
||||
let input_size = 3;
|
||||
let hidden_size = 5;
|
||||
let batch_size = 1;
|
||||
let sequence_length = 4;
|
||||
let number_directions = 1;
|
||||
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
-1.5255959033966064,
|
||||
-0.7502318024635315,
|
||||
-0.6539809107780457,
|
||||
-1.6094847917556763,
|
||||
-0.1001671776175499,
|
||||
-0.6091889142990112,
|
||||
-0.9797722697257996,
|
||||
-1.6090962886810303,
|
||||
-0.7121446132659912,
|
||||
0.30372199416160583,
|
||||
-0.777314305305481,
|
||||
-0.25145524740219116,
|
||||
-0.22227048873901367,
|
||||
1.6871134042739868,
|
||||
0.22842517495155334,
|
||||
0.46763551235198975,
|
||||
-0.6969724297523499,
|
||||
-1.1607614755630493,
|
||||
0.6995424032211304,
|
||||
0.1990816295146942,
|
||||
0.8656923770904541,
|
||||
0.2444038987159729,
|
||||
-0.6629113554954529,
|
||||
0.8073082566261292,
|
||||
1.1016806364059448,
|
||||
-0.1759360432624817,
|
||||
-2.2455577850341797,
|
||||
-1.4464579820632935,
|
||||
0.0611552819609642,
|
||||
-0.6177444458007812,
|
||||
-0.7980698347091675,
|
||||
-0.13162320852279663,
|
||||
1.8793457746505737,
|
||||
-0.07213178277015686,
|
||||
0.15777060389518738,
|
||||
-0.7734549045562744,
|
||||
0.1990565061569214,
|
||||
0.04570277780294418,
|
||||
0.15295691788196564,
|
||||
-0.47567880153656006,
|
||||
-0.11101982742547989,
|
||||
0.2927352488040924,
|
||||
-0.1578451544046402,
|
||||
-0.028787139803171158,
|
||||
0.4532545804977417,
|
||||
1.1421611309051514,
|
||||
0.2486107051372528,
|
||||
-1.7754007577896118,
|
||||
-0.025502461940050125,
|
||||
-1.023330569267273,
|
||||
-0.5961851477622986,
|
||||
-1.0055307149887085,
|
||||
0.42854228615760803,
|
||||
1.4760777950286865,
|
||||
-1.7868678569793701,
|
||||
1.610317587852478,
|
||||
-0.703956663608551,
|
||||
-0.18526579439640045,
|
||||
-0.9962350726127625,
|
||||
-0.8312552571296692,
|
||||
],
|
||||
(20, 3),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.4099724292755127,
|
||||
0.4084506630897522,
|
||||
0.25786539912223816,
|
||||
1.095021367073059,
|
||||
-0.5064865946769714,
|
||||
0.09977540373802185,
|
||||
-0.653973400592804,
|
||||
0.731693685054779,
|
||||
-1.456732988357544,
|
||||
1.6089353561401367,
|
||||
0.09376997500658035,
|
||||
-1.2597490549087524,
|
||||
0.25463348627090454,
|
||||
-0.5019572973251343,
|
||||
-1.041200041770935,
|
||||
0.7322672009468079,
|
||||
1.3075355291366577,
|
||||
-1.1627987623214722,
|
||||
0.11963611096143723,
|
||||
-0.1631353348493576,
|
||||
0.6614453196525574,
|
||||
1.1899205446243286,
|
||||
0.8165339231491089,
|
||||
-0.9135236144065857,
|
||||
-0.3538065254688263,
|
||||
0.7639270424842834,
|
||||
-0.5889506936073303,
|
||||
-0.7635973691940308,
|
||||
1.3352056741714478,
|
||||
0.6042736172676086,
|
||||
-0.10344208031892776,
|
||||
-0.15121692419052124,
|
||||
1.2465683221817017,
|
||||
0.505721390247345,
|
||||
0.9505112171173096,
|
||||
1.2966482639312744,
|
||||
0.873796284198761,
|
||||
-0.5602594017982483,
|
||||
1.2857844829559326,
|
||||
0.8168238401412964,
|
||||
-1.464799404144287,
|
||||
-1.2629283666610718,
|
||||
1.122018814086914,
|
||||
1.5663341283798218,
|
||||
2.558138370513916,
|
||||
-0.23336388170719147,
|
||||
-0.013472129590809345,
|
||||
1.8606348037719727,
|
||||
1.549620509147644,
|
||||
0.34762924909591675,
|
||||
0.09300802648067474,
|
||||
0.6147403120994568,
|
||||
0.7123645544052124,
|
||||
-1.7765072584152222,
|
||||
0.3538645803928375,
|
||||
1.1996132135391235,
|
||||
-0.7122589349746704,
|
||||
-0.620034396648407,
|
||||
-0.22813494503498077,
|
||||
-0.7892746329307556,
|
||||
-1.6111117601394653,
|
||||
-1.8716129064559937,
|
||||
0.5430836081504822,
|
||||
0.6606786251068115,
|
||||
0.270527720451355,
|
||||
0.5596919655799866,
|
||||
-0.31839630007743835,
|
||||
1.5117206573486328,
|
||||
-1.363267183303833,
|
||||
-0.9832196235656738,
|
||||
1.5112667083740234,
|
||||
0.6418707370758057,
|
||||
-0.7474458813667297,
|
||||
-0.923438549041748,
|
||||
0.5733984112739563,
|
||||
-0.10929951071739197,
|
||||
0.5181121230125427,
|
||||
0.10653535276651382,
|
||||
0.26924076676368713,
|
||||
1.3247679471969604,
|
||||
0.037456899881362915,
|
||||
-0.6378393173217773,
|
||||
-0.8147554397583008,
|
||||
-0.6895065307617188,
|
||||
0.8436542749404907,
|
||||
1.1657012701034546,
|
||||
0.5269321799278259,
|
||||
1.6192532777786255,
|
||||
-0.963976263999939,
|
||||
0.14152038097381592,
|
||||
-0.1636609584093094,
|
||||
-0.3582225739955902,
|
||||
1.7222793102264404,
|
||||
-0.3035756051540375,
|
||||
0.23887419700622559,
|
||||
1.3440011739730835,
|
||||
0.1032256931066513,
|
||||
1.1003541946411133,
|
||||
-0.3416801989078522,
|
||||
0.947338879108429,
|
||||
],
|
||||
(20, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
-0.568515956401825,
|
||||
0.8375961780548096,
|
||||
1.783660650253296,
|
||||
-0.1954246610403061,
|
||||
0.235193133354187,
|
||||
1.9142433404922485,
|
||||
1.8364111185073853,
|
||||
1.324532389640808,
|
||||
-0.07051458209753036,
|
||||
0.34697940945625305,
|
||||
-0.653679609298706,
|
||||
1.5586202144622803,
|
||||
0.2185661494731903,
|
||||
-0.5743072628974915,
|
||||
1.4571250677108765,
|
||||
1.7709556818008423,
|
||||
-2.0172998905181885,
|
||||
0.42350319027900696,
|
||||
0.5730220079421997,
|
||||
-1.7962429523468018,
|
||||
],
|
||||
(20,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
1.2470403909683228,
|
||||
1.2738511562347412,
|
||||
0.3909492492675781,
|
||||
0.387210488319397,
|
||||
0.14440394937992096,
|
||||
0.7771684527397156,
|
||||
-2.3381125926971436,
|
||||
-0.829120397567749,
|
||||
1.1661391258239746,
|
||||
1.4786574840545654,
|
||||
0.26760873198509216,
|
||||
0.7561198472976685,
|
||||
-0.5873361229896545,
|
||||
-2.061920642852783,
|
||||
0.4304734766483307,
|
||||
0.3376566171646118,
|
||||
-0.3437853455543518,
|
||||
-0.6172260642051697,
|
||||
1.2529692649841309,
|
||||
-0.05141742154955864,
|
||||
],
|
||||
(20,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let input = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.6472128033638,
|
||||
-0.04116716980934143,
|
||||
-0.17749308049678802,
|
||||
-0.500039279460907,
|
||||
0.8672749400138855,
|
||||
-0.27319222688674927,
|
||||
-0.4607681334018707,
|
||||
-0.0990937128663063,
|
||||
0.47284480929374695,
|
||||
1.0049484968185425,
|
||||
-0.2871420383453369,
|
||||
-1.1618621349334717,
|
||||
],
|
||||
(4, 1, 3),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let h0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.02758178487420082,
|
||||
0.5652382373809814,
|
||||
-0.011487378738820553,
|
||||
0.6706400513648987,
|
||||
-0.4929250478744507,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let c0 = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
1.505028486251831,
|
||||
-2.32635498046875,
|
||||
1.6168899536132812,
|
||||
-0.9026237726211548,
|
||||
0.17366823554039001,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let output = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.5956016778945923,
|
||||
-0.01723279245197773,
|
||||
0.11035571992397308,
|
||||
-0.49323174357414246,
|
||||
0.047632161527872086,
|
||||
0.6358451843261719,
|
||||
0.040328118950128555,
|
||||
-0.3788611590862274,
|
||||
-0.7464339733123779,
|
||||
0.20080909132957458,
|
||||
0.5840265154838562,
|
||||
0.1453288197517395,
|
||||
-0.7345298528671265,
|
||||
-0.5214304327964783,
|
||||
0.21903817355632782,
|
||||
0.7420451641082764,
|
||||
0.31943878531455994,
|
||||
-0.04726646468043327,
|
||||
-0.2823849618434906,
|
||||
0.2713133990764618,
|
||||
],
|
||||
(4, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let hn = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.7420451641082764,
|
||||
0.31943878531455994,
|
||||
-0.04726646468043327,
|
||||
-0.2823849618434906,
|
||||
0.2713133990764618,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let cn = Tensor::from_vec::<_, f32>(
|
||||
vec![
|
||||
0.9630558490753174,
|
||||
1.0033069849014282,
|
||||
-1.754899024963379,
|
||||
-1.5967122316360474,
|
||||
0.8252924680709839,
|
||||
],
|
||||
(1, 1, 5),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
// end of generated values
|
||||
|
||||
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "LSTM".to_string(),
|
||||
name: "LSTM_test".to_string(),
|
||||
attribute: vec![AttributeProto {
|
||||
name: "hidden_size".to_string(),
|
||||
r#type: AttributeType::Int.into(),
|
||||
i: hidden_size as i64,
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
input: vec![
|
||||
"input".to_string(),
|
||||
"w".to_string(),
|
||||
"r".to_string(),
|
||||
"b".to_string(), // b
|
||||
"".to_string(), // seq_lens
|
||||
"h".to_string(),
|
||||
"c".to_string(),
|
||||
],
|
||||
output: vec!["output".to_string(), "hn".to_string(), "cn".to_string()],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
input: ["input", "w", "r", "b", "h", "c"]
|
||||
.into_iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
output: ["output", "hn", "cn"]
|
||||
.into_iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
..GraphProto::default()
|
||||
}));
|
||||
// pytorch stores weight and bias as [ifco] but we want it as [iofc]
|
||||
// so we need to re-arrange the tensors a bit
|
||||
let idx_iofc = {
|
||||
let stride = hidden_size as i64;
|
||||
let dev = weight_ih_l0.device();
|
||||
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
|
||||
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
|
||||
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
|
||||
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
|
||||
|
||||
Tensor::cat(&[&idx_i, &idx_o, &idx_f, &idx_g], 0)?
|
||||
};
|
||||
let w = weight_ih_l0.index_select(&idx_iofc, 0)?;
|
||||
let w = w.reshape((number_directions, 4 * hidden_size, input_size))?;
|
||||
let r = weight_hh_l0.index_select(&idx_iofc, 0)?;
|
||||
let r = r.reshape((number_directions, 4 * hidden_size, hidden_size))?;
|
||||
let wb = bias_ih_l0.index_select(&idx_iofc, 0)?;
|
||||
let rb = bias_hh_l0.index_select(&idx_iofc, 0)?;
|
||||
let b = Tensor::cat(&[wb, rb], 0)?.reshape((number_directions, 8 * hidden_size))?;
|
||||
let output = output.reshape((sequence_length, number_directions, batch_size, hidden_size))?;
|
||||
let result = simple_eval(
|
||||
&model,
|
||||
HashMap::from_iter([
|
||||
("input".to_string(), input),
|
||||
("w".to_string(), w),
|
||||
("r".to_string(), r),
|
||||
("b".to_string(), b),
|
||||
("h".to_string(), h0),
|
||||
("c".to_string(), c0),
|
||||
]),
|
||||
)?;
|
||||
let actual_output = result.get("output").unwrap();
|
||||
assert_eq!(output.dims(), actual_output.dims());
|
||||
let actual_hn = result.get("hn").unwrap();
|
||||
assert_eq!(hn.dims(), actual_hn.dims());
|
||||
let actual_cn = result.get("cn").unwrap();
|
||||
assert_eq!(cn.dims(), actual_cn.dims());
|
||||
let diff_close_enough = |a: &Tensor, b| -> Result<_> {
|
||||
let diffs = a.sub(b)?.flatten_all()?.to_vec1::<f32>()?;
|
||||
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
|
||||
};
|
||||
assert!(
|
||||
diff_close_enough(&output, &actual_output)?,
|
||||
"output did not match expected\n{actual_output}\n{output}",
|
||||
);
|
||||
assert!(
|
||||
diff_close_enough(&hn, &actual_hn)?,
|
||||
"hn did not match expected\n{actual_hn}\n{hn}",
|
||||
);
|
||||
assert!(
|
||||
diff_close_enough(&cn, &actual_cn)?,
|
||||
"cn did not match expected\n{actual_cn}\n{cn}",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
589
candle-transformers/src/models/based.rs
Normal file
589
candle-transformers/src/models/based.rs
Normal file
@ -0,0 +1,589 @@
|
||||
//! Based from the Stanford Hazy Research group.
|
||||
//!
|
||||
//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
|
||||
//! <https://arxiv.org/abs/2402.18668>
|
||||
|
||||
//! Original code:
|
||||
//! https://github.com/HazyResearch/based
|
||||
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
|
||||
Func, Linear, RmsNorm, VarBuilder,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct LinearAttentionFeatureMapConfig {
|
||||
input_dim: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct LinearAttentionConfig {
|
||||
num_heads: usize,
|
||||
feature_dim: usize,
|
||||
feature_map: LinearAttentionFeatureMapConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct SlidingWindowAttentionConfig {
|
||||
num_heads: usize,
|
||||
window_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
#[serde(rename = "n_embd")]
|
||||
hidden_size: usize,
|
||||
#[serde(rename = "n_inner")]
|
||||
intermediate_size: usize,
|
||||
#[serde(rename = "n_layer")]
|
||||
num_hidden_layers: usize,
|
||||
#[serde(rename = "n_head")]
|
||||
num_attention_heads: usize,
|
||||
|
||||
layer_norm_epsilon: f64,
|
||||
#[serde(default = "default_rope", rename = "rotary_emb_base")]
|
||||
rope_theta: f64,
|
||||
|
||||
alt_mixer_layers: Vec<usize>,
|
||||
alt_mixer_2_layers: Vec<usize>,
|
||||
#[serde(rename = "alt_mixer")]
|
||||
la: LinearAttentionConfig,
|
||||
#[serde(rename = "alt_mixer_2")]
|
||||
swa: SlidingWindowAttentionConfig,
|
||||
}
|
||||
|
||||
fn default_rope() -> f64 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
|
||||
let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
// Swiglu implementation.
|
||||
// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs
|
||||
fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.chunk(2, D::Minus1)?;
|
||||
&xs[1].silu()? * &xs[0]
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.fc1)?;
|
||||
let xs = swiglu(&xs)?;
|
||||
let xs = xs.apply(&self.fc2)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// A gated convolutional block.
|
||||
#[derive(Debug, Clone)]
|
||||
struct BasedConv {
|
||||
in_proj: Linear,
|
||||
out_proj: Linear,
|
||||
conv: Conv1d,
|
||||
state: Tensor,
|
||||
}
|
||||
|
||||
impl BasedConv {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dim = cfg.hidden_size * 2;
|
||||
|
||||
let conv1d_cfg = Conv1dConfig {
|
||||
groups: dim,
|
||||
padding: 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
|
||||
let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
|
||||
let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
|
||||
let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
|
||||
Ok(Self {
|
||||
in_proj,
|
||||
out_proj,
|
||||
conv,
|
||||
state,
|
||||
})
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.state = self.state.roll(-1, D::Minus1)?;
|
||||
let (_, _, l) = self.state.dims3()?;
|
||||
self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
|
||||
self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;
|
||||
|
||||
let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
|
||||
.sum_keepdim(0)?
|
||||
.sum(D::Minus1)?;
|
||||
|
||||
let xs = xs.unsqueeze(1)?;
|
||||
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let xs = xs.apply(&self.in_proj)?;
|
||||
let us = xs.chunk(2, D::Minus1)?;
|
||||
let (_b, l, _d) = us[0].dims3()?;
|
||||
let u_conv = if seqlen_offset > 0 {
|
||||
self.step(&us[0])?
|
||||
} else {
|
||||
let k = std::cmp::min(3, l);
|
||||
self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
|
||||
let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
|
||||
self.state = Tensor::cat(&[&self.state, &xs], 2)?;
|
||||
|
||||
us[0]
|
||||
.transpose(1, 2)?
|
||||
.apply(&self.conv)?
|
||||
.narrow(D::Minus1, 0, l)?
|
||||
.transpose(1, 2)?
|
||||
};
|
||||
|
||||
let u_conv = u_conv.silu()?;
|
||||
let v = u_conv.broadcast_mul(&us[1])?;
|
||||
let xs = v.apply(&self.out_proj)?;
|
||||
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// Linear attention approximating softmax using second order Taylor polynomials.
|
||||
#[derive(Debug, Clone)]
|
||||
struct LinearAttention {
|
||||
proj_q: Linear,
|
||||
proj_k: Linear,
|
||||
proj_v: Linear,
|
||||
out_proj: Linear,
|
||||
feature_dim: usize,
|
||||
num_heads: usize,
|
||||
input_dim: usize,
|
||||
k_state: Tensor,
|
||||
kv_state: Tensor,
|
||||
}
|
||||
|
||||
impl LinearAttention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let input_dim = cfg.la.feature_map.input_dim;
|
||||
let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
|
||||
let proj_k = linear_no_bias(
|
||||
cfg.hidden_size,
|
||||
cfg.la.num_heads * cfg.la.feature_dim,
|
||||
vb.pp("proj_k"),
|
||||
)?;
|
||||
let proj_q = linear_no_bias(
|
||||
cfg.hidden_size,
|
||||
cfg.la.num_heads * cfg.la.feature_dim,
|
||||
vb.pp("proj_q"),
|
||||
)?;
|
||||
|
||||
let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
|
||||
let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
|
||||
let k_state = Tensor::zeros(
|
||||
(1, cfg.la.num_heads, 1, 1, expanded_size),
|
||||
vb.dtype(),
|
||||
vb.device(),
|
||||
)?;
|
||||
let kv_state = Tensor::zeros(
|
||||
(1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
|
||||
vb.dtype(),
|
||||
vb.device(),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
proj_q,
|
||||
proj_k,
|
||||
proj_v,
|
||||
out_proj,
|
||||
feature_dim: cfg.la.feature_dim,
|
||||
num_heads: cfg.la.num_heads,
|
||||
input_dim,
|
||||
k_state,
|
||||
kv_state,
|
||||
})
|
||||
}
|
||||
|
||||
fn taylor_expansion(&self) -> Result<Func<'static>> {
|
||||
let r2 = std::f64::consts::SQRT_2;
|
||||
let rd = (self.input_dim as f64).sqrt();
|
||||
let rrd = rd.sqrt();
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let dims = xs.dims();
|
||||
let mut d = dims.to_vec();
|
||||
if let Some(last) = d.last_mut() {
|
||||
*last = 1;
|
||||
};
|
||||
|
||||
let x = xs
|
||||
.unsqueeze(D::Minus1)?
|
||||
.broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
|
||||
let x = (x.flatten_from(D::Minus2)? / r2)?;
|
||||
let o = Tensor::ones(d, xs.dtype(), xs.device())?;
|
||||
let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;
|
||||
|
||||
Ok(x)
|
||||
}))
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let eps = 1e-12;
|
||||
|
||||
let feature_map = self.taylor_expansion()?;
|
||||
|
||||
let (b, l, d) = xs.dims3()?;
|
||||
let q = xs.apply(&self.proj_q)?;
|
||||
let k = xs.apply(&self.proj_k)?;
|
||||
let v = xs.apply(&self.proj_v)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b, l, self.num_heads, self.feature_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b, l, self.num_heads, self.feature_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let v = v
|
||||
.reshape((b, l, self.num_heads, d / self.num_heads))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
let q = feature_map.forward(&q)?;
|
||||
let k = feature_map.forward(&k)?;
|
||||
|
||||
let y = if seqlen_offset > 0 {
|
||||
let (_b, _h, l, _d) = k.dims4()?;
|
||||
let q = q.unsqueeze(D::Minus2)?;
|
||||
let k = k.unsqueeze(D::Minus2)?;
|
||||
let v = v.unsqueeze(D::Minus1)?;
|
||||
let kn = k.narrow(D::Minus1, l - 1, 1)?;
|
||||
let vn = v.narrow(D::Minus1, l - 1, 1)?;
|
||||
|
||||
self.k_state = self.k_state.broadcast_add(&kn)?;
|
||||
self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;
|
||||
|
||||
let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
|
||||
let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
|
||||
num.broadcast_div(&den)?
|
||||
} else {
|
||||
self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
|
||||
self.kv_state = k
|
||||
.transpose(2, 3)?
|
||||
.matmul(&v)?
|
||||
.transpose(2, 3)?
|
||||
.unsqueeze(2)?;
|
||||
let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
|
||||
let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
|
||||
let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;
|
||||
|
||||
let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
|
||||
aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
|
||||
};
|
||||
|
||||
let (b, h, l, d) = y.dims4()?;
|
||||
let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
|
||||
let y = self.out_proj.forward(&y)?;
|
||||
|
||||
Ok(y)
|
||||
}
|
||||
}
|
||||
|
||||
// Rotary embeddings used in local attention.
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
let max_seq_len = 2048; // Hardcoded, missing from config.
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
// Local attention using a small sliding window.
|
||||
#[derive(Debug, Clone)]
|
||||
struct SlidingWindowAttention {
|
||||
wqkv: Linear,
|
||||
out_proj: Linear,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl SlidingWindowAttention {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let num_heads = cfg.swa.num_heads;
|
||||
let head_dim = hidden_size / num_heads;
|
||||
let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
|
||||
let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
out_proj,
|
||||
hidden_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let qkv = xs.apply(&self.wqkv)?;
|
||||
let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;
|
||||
|
||||
let q = qkv.i((.., .., 0))?;
|
||||
let k = qkv.i((.., .., 1))?;
|
||||
let v = qkv.i((.., .., 2))?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (q, k) = self
|
||||
.rotary_emb
|
||||
.apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &k], 2)?;
|
||||
let v = Tensor::cat(&[prev_v, &v], 2)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
let out = attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.out_proj)?;
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
// The model layers use three types of mixers.
|
||||
#[derive(Debug, Clone)]
|
||||
enum SequenceMixer {
|
||||
Based(BasedConv),
|
||||
Linear(LinearAttention),
|
||||
Sliding(SlidingWindowAttention),
|
||||
}
|
||||
|
||||
impl SequenceMixer {
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
pos: usize,
|
||||
) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::Based(b) => b.forward(xs, pos),
|
||||
Self::Linear(b) => b.forward(xs, pos),
|
||||
Self::Sliding(b) => b.forward(xs, attention_mask, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
mlp: MLP,
|
||||
norm1: RmsNorm,
|
||||
norm2: RmsNorm,
|
||||
mixer: SequenceMixer,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
|
||||
let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;
|
||||
|
||||
let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
|
||||
let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);
|
||||
|
||||
let mixer = if l_attn {
|
||||
SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
|
||||
} else if sw_attn {
|
||||
SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
|
||||
} else {
|
||||
SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
mlp,
|
||||
norm1,
|
||||
norm2,
|
||||
mixer,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.norm1.forward(xs)?;
|
||||
let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: super::with_tracing::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
|
||||
let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
|
||||
let vb_m = vb.pp("transformer");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
sliding_window: cfg.swa.window_size,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let sliding_window = self.sliding_window / 2;
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
}
|
@ -419,7 +419,7 @@ struct BertEncoder {
|
||||
impl BertEncoder {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
||||
.map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
Ok(BertEncoder { layers, span })
|
||||
@ -454,8 +454,8 @@ impl BertModel {
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let Some(model_type) = &config.model_type {
|
||||
if let (Ok(embeddings), Ok(encoder)) = (
|
||||
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
||||
BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
|
||||
BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
|
||||
) {
|
||||
(embeddings, encoder)
|
||||
} else {
|
||||
@ -501,5 +501,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
||||
};
|
||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||
// torch.finfo(dtype).min
|
||||
(attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
|
||||
(attention_mask.ones_like()? - &attention_mask)?
|
||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||
}
|
||||
|
@ -298,7 +298,7 @@ impl GPTBigCode {
|
||||
let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
|
||||
let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
|
||||
.map(|i| Block::load(vb_t.pp(format!("h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
||||
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
|
||||
|
376
candle-transformers/src/models/dac.rs
Normal file
376
candle-transformers/src/models/dac.rs
Normal file
@ -0,0 +1,376 @@
|
||||
/// Adapted from https://github.com/descriptinc/descript-audio-codec
|
||||
use crate::models::encodec;
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder};
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub num_codebooks: usize,
|
||||
pub model_bitrate: u32,
|
||||
pub codebook_size: usize,
|
||||
pub latent_dim: usize,
|
||||
pub frame_rate: u32,
|
||||
pub sampling_rate: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Snake1d {
|
||||
alpha: Tensor,
|
||||
}
|
||||
|
||||
impl Snake1d {
|
||||
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let alpha = vb.get((1, channels, 1), "alpha")?;
|
||||
Ok(Self { alpha })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Snake1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs_shape = xs.shape();
|
||||
let xs = xs.flatten_from(2)?;
|
||||
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
|
||||
let sin = (&sin * &sin)?;
|
||||
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResidualUnit {
|
||||
snake1: Snake1d,
|
||||
conv1: Conv1d,
|
||||
snake2: Snake1d,
|
||||
conv2: Conv1d,
|
||||
}
|
||||
|
||||
impl ResidualUnit {
|
||||
pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let pad = ((7 - 1) * dilation) / 2;
|
||||
let vb = vb.pp("block");
|
||||
let snake1 = Snake1d::new(dim, vb.pp(0))?;
|
||||
let cfg1 = Conv1dConfig {
|
||||
dilation,
|
||||
padding: pad,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
|
||||
let snake2 = Snake1d::new(dim, vb.pp(2))?;
|
||||
let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
|
||||
Ok(Self {
|
||||
snake1,
|
||||
conv1,
|
||||
snake2,
|
||||
conv2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for ResidualUnit {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let ys = xs
|
||||
.apply(&self.snake1)?
|
||||
.apply(&self.conv1)?
|
||||
.apply(&self.snake2)?
|
||||
.apply(&self.conv2)?;
|
||||
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
|
||||
if pad > 0 {
|
||||
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
|
||||
} else {
|
||||
ys + xs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EncoderBlock {
|
||||
res1: ResidualUnit,
|
||||
res2: ResidualUnit,
|
||||
res3: ResidualUnit,
|
||||
snake1: Snake1d,
|
||||
conv1: Conv1d,
|
||||
}
|
||||
|
||||
impl EncoderBlock {
|
||||
pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let vb = vb.pp("block");
|
||||
let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?;
|
||||
let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?;
|
||||
let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?;
|
||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||
let cfg1 = Conv1dConfig {
|
||||
stride,
|
||||
padding: (stride + 1) / 2,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||
Ok(Self {
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
snake1,
|
||||
conv1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for EncoderBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.res1)?
|
||||
.apply(&self.res2)?
|
||||
.apply(&self.res3)?
|
||||
.apply(&self.snake1)?
|
||||
.apply(&self.conv1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Encoder {
|
||||
conv1: Conv1d,
|
||||
blocks: Vec<EncoderBlock>,
|
||||
snake1: Snake1d,
|
||||
conv2: Conv1d,
|
||||
}
|
||||
|
||||
impl candle::Module for Encoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.conv1)?;
|
||||
for block in self.blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
xs.apply(&self.snake1)?.apply(&self.conv2)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
pub fn new(
|
||||
mut d_model: usize,
|
||||
strides: &[usize],
|
||||
d_latent: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("block");
|
||||
let cfg1 = Conv1dConfig {
|
||||
padding: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?;
|
||||
let mut blocks = Vec::with_capacity(strides.len());
|
||||
for (block_idx, stride) in strides.iter().enumerate() {
|
||||
d_model *= 2;
|
||||
let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?;
|
||||
blocks.push(block)
|
||||
}
|
||||
let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?;
|
||||
let cfg2 = Conv1dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv2 =
|
||||
encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
blocks,
|
||||
snake1,
|
||||
conv2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecoderBlock {
|
||||
snake1: Snake1d,
|
||||
conv_tr1: ConvTranspose1d,
|
||||
res1: ResidualUnit,
|
||||
res2: ResidualUnit,
|
||||
res3: ResidualUnit,
|
||||
}
|
||||
|
||||
impl DecoderBlock {
|
||||
pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let vb = vb.pp("block");
|
||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||
let cfg = ConvTranspose1dConfig {
|
||||
stride,
|
||||
padding: (stride + 1) / 2,
|
||||
..Default::default()
|
||||
};
|
||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||
in_dim,
|
||||
out_dim,
|
||||
2 * stride,
|
||||
true,
|
||||
cfg,
|
||||
vb.pp(1),
|
||||
)?;
|
||||
let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?;
|
||||
let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?;
|
||||
let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?;
|
||||
Ok(Self {
|
||||
snake1,
|
||||
conv_tr1,
|
||||
res1,
|
||||
res2,
|
||||
res3,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle_nn::Module for DecoderBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.snake1)?
|
||||
.apply(&self.conv_tr1)?
|
||||
.apply(&self.res1)?
|
||||
.apply(&self.res2)?
|
||||
.apply(&self.res3)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Decoder {
|
||||
conv1: Conv1d,
|
||||
blocks: Vec<DecoderBlock>,
|
||||
snake1: Snake1d,
|
||||
conv2: Conv1d,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
mut channels: usize,
|
||||
rates: &[usize],
|
||||
d_out: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("model");
|
||||
let cfg1 = Conv1dConfig {
|
||||
padding: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?;
|
||||
let mut blocks = Vec::with_capacity(rates.len());
|
||||
for (idx, stride) in rates.iter().enumerate() {
|
||||
let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?;
|
||||
channels /= 2;
|
||||
blocks.push(block)
|
||||
}
|
||||
let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?;
|
||||
let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?;
|
||||
Ok(Self {
|
||||
conv1,
|
||||
blocks,
|
||||
snake1,
|
||||
conv2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Decoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.conv1)?;
|
||||
for block in self.blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
xs.apply(&self.snake1)?.apply(&self.conv2)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VectorQuantizer {
|
||||
in_proj: Conv1d,
|
||||
out_proj: Conv1d,
|
||||
codebook: candle_nn::Embedding,
|
||||
}
|
||||
|
||||
impl VectorQuantizer {
|
||||
pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let in_proj =
|
||||
encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
|
||||
let out_proj =
|
||||
encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
|
||||
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
|
||||
Ok(Self {
|
||||
in_proj,
|
||||
out_proj,
|
||||
codebook,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
||||
embed_id.apply(&self.codebook)
|
||||
}
|
||||
|
||||
pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
||||
self.embed_code(embed_id)?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ResidualVectorQuantizer {
|
||||
quantizers: Vec<VectorQuantizer>,
|
||||
}
|
||||
|
||||
impl ResidualVectorQuantizer {
|
||||
pub fn new(
|
||||
input_dim: usize,
|
||||
n_codebooks: usize,
|
||||
cb_size: usize,
|
||||
cb_dim: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = &vb.pp("quantizers");
|
||||
let quantizers = (0..n_codebooks)
|
||||
.map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { quantizers })
|
||||
}
|
||||
|
||||
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut sum = None;
|
||||
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
||||
let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?;
|
||||
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
|
||||
let s = match sum {
|
||||
None => z_q_i,
|
||||
Some(s) => (s + z_q_i)?,
|
||||
};
|
||||
sum = Some(s)
|
||||
}
|
||||
match sum {
|
||||
Some(s) => Ok(s),
|
||||
None => candle::bail!("empty codebooks"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
pub encoder: Encoder,
|
||||
pub quantizer: ResidualVectorQuantizer,
|
||||
pub decoder: Decoder,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb = vb.pp("model");
|
||||
let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?;
|
||||
let quantizer = ResidualVectorQuantizer::new(
|
||||
cfg.latent_dim,
|
||||
cfg.num_codebooks,
|
||||
cfg.codebook_size,
|
||||
8,
|
||||
vb.pp("quantizer"),
|
||||
)?;
|
||||
let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
quantizer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> {
|
||||
let audio_values = self.quantizer.from_codes(audio_codes)?;
|
||||
audio_values.apply(&self.decoder)
|
||||
}
|
||||
}
|
@ -275,7 +275,7 @@ struct Transformer {
|
||||
impl Transformer {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.n_layers)
|
||||
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
|
||||
.map(|index| TransformerBlock::load(vb.pp(format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
Ok(Transformer { layers, span })
|
||||
@ -311,8 +311,8 @@ impl DistilBertModel {
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let Some(model_type) = &config.model_type {
|
||||
if let (Ok(embeddings), Ok(encoder)) = (
|
||||
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
|
||||
Embeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
|
||||
Transformer::load(vb.pp(format!("{model_type}.transformer")), config),
|
||||
) {
|
||||
(embeddings, encoder)
|
||||
} else {
|
||||
|
@ -136,7 +136,7 @@ pub fn conv1d_weight_norm(
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
fn conv_transpose1d_weight_norm(
|
||||
pub fn conv_transpose1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
|
@ -448,7 +448,7 @@ impl Falcon {
|
||||
vb.pp("transformer.word_embeddings"),
|
||||
)?;
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
|
||||
.map(|i| FalconDecoderLayer::load(vb.pp(format!("transformer.h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(
|
||||
cfg.hidden_size,
|
||||
|
512
candle-transformers/src/models/fastvit.rs
Normal file
512
candle-transformers/src/models/fastvit.rs
Normal file
@ -0,0 +1,512 @@
|
||||
//! FastViT inference implementation based on timm
|
||||
//!
|
||||
//! See "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"
|
||||
//! https://arxiv.org/pdf/2303.14189
|
||||
//!
|
||||
//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py
|
||||
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
|
||||
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
exp_ratio: usize,
|
||||
in_channels: usize,
|
||||
blocks: [usize; 4],
|
||||
attn: bool,
|
||||
lkc_use_act: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn t8() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 48,
|
||||
blocks: [2, 2, 4, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn t12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn s12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: false,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa12() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [2, 2, 6, 2],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa24() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [4, 4, 12, 4],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn sa36() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 64,
|
||||
blocks: [6, 6, 18, 6],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
pub fn ma36() -> Self {
|
||||
Self {
|
||||
exp_ratio: 4,
|
||||
in_channels: 76,
|
||||
blocks: [6, 6, 18, 6],
|
||||
attn: true,
|
||||
lkc_use_act: false,
|
||||
}
|
||||
}
|
||||
|
||||
// configs used by MobileCLIP's image encoder
|
||||
pub fn mci0() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [2, 6, 10, 2],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
pub fn mci1() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 64,
|
||||
blocks: [4, 12, 20, 4],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
pub fn mci2() -> Self {
|
||||
Self {
|
||||
exp_ratio: 3,
|
||||
in_channels: 80,
|
||||
blocks: [4, 12, 24, 4],
|
||||
attn: true,
|
||||
lkc_use_act: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn conv_norm(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding: kernel / 2,
|
||||
groups: in_channels,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
|
||||
let conv = conv.absorb_bn(&bn)?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&conv)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn conv_mlp(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv = conv_norm(dim, dim, 7, 1, vb.pp("conv"))?;
|
||||
let fc1 = conv2d(dim, dim * exp_ratio, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(dim * exp_ratio, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&conv)?.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn squeeze_and_excitation(
|
||||
in_channels: usize,
|
||||
squeeze_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
|
||||
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs;
|
||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
||||
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
|
||||
|
||||
residual.broadcast_mul(&xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
||||
// based on the _fuse_bn_tensor method in timm
|
||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
||||
let mu = bn.running_mean();
|
||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
||||
let gps = (gamma / sigma)?;
|
||||
let bias = (beta - mu * &gps)?;
|
||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
||||
|
||||
Ok((weights, bias))
|
||||
}
|
||||
|
||||
fn mobileone_block(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
group_size: usize,
|
||||
use_act: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let groups = if group_size == 0 {
|
||||
1
|
||||
} else {
|
||||
in_channels / group_size
|
||||
};
|
||||
|
||||
let padding = kernel / 2;
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
groups,
|
||||
padding,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut w = Tensor::zeros(
|
||||
(out_channels, in_channels / groups, kernel, kernel),
|
||||
DType::F32,
|
||||
vb.device(),
|
||||
)?;
|
||||
let dim = out_channels;
|
||||
|
||||
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
|
||||
|
||||
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp("conv_kxk.0.bn"));
|
||||
let conv_kxk = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_kxk.0.conv"),
|
||||
);
|
||||
|
||||
if let (Ok(conv), Ok(bn)) = (conv_kxk, conv_kxk_bn) {
|
||||
let (wk, bk) = fuse_conv_bn(conv.weight(), bn)?;
|
||||
w = (w + wk)?;
|
||||
b = (b + bk)?;
|
||||
};
|
||||
|
||||
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"));
|
||||
let conv_scale = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv2d_cfg,
|
||||
vb.pp("conv_scale.conv"),
|
||||
);
|
||||
|
||||
if let (Ok(conv), Ok(bn)) = (conv_scale, conv_scale_bn) {
|
||||
let (ws, bs) = fuse_conv_bn(conv.weight(), bn)?;
|
||||
// pad to 3x3
|
||||
let ws = ws
|
||||
.pad_with_zeros(D::Minus1, 1, 1)?
|
||||
.pad_with_zeros(D::Minus2, 1, 1)?;
|
||||
|
||||
w = (w + ws)?;
|
||||
b = (b + bs)?;
|
||||
};
|
||||
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("se"));
|
||||
|
||||
// read and reparameterize the identity bn into wi and bi
|
||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"));
|
||||
|
||||
if let Ok(id_bn) = identity_bn {
|
||||
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
|
||||
let id = in_channels / groups;
|
||||
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
|
||||
for i in 0..in_channels {
|
||||
if kernel > 1 {
|
||||
weights[i * kernel * kernel + 4] = 1.0;
|
||||
} else {
|
||||
weights[i * (id + 1)] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
||||
let (wi, bi) = fuse_conv_bn(weights, id_bn)?;
|
||||
|
||||
w = (w + wi)?;
|
||||
b = (b + bi)?;
|
||||
};
|
||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.apply(&reparam_conv)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
if use_act {
|
||||
xs = xs.gelu_erf()?;
|
||||
};
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn repmixer(dim: usize, kernel: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
|
||||
let norm = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("norm"))?;
|
||||
let mixer = mobileone_block(dim, dim, kernel, 1, 1, false, vb.pp("mixer"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs.clone();
|
||||
let xs = (xs.apply(&mixer)? - xs.apply(&norm)?)?;
|
||||
let xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
|
||||
let xs = (xs + residual)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn repmixer_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma = vb.get((dim, 1, 1), "layer_scale.gamma")?;
|
||||
let token_mixer = repmixer(dim, 3, vb.pp("token_mixer"))?;
|
||||
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs.apply(&token_mixer)?;
|
||||
let mut xs = residual.apply(&mlp)?;
|
||||
xs = xs.broadcast_mul(&gamma.reshape((1, (), 1, 1))?)?;
|
||||
let xs = (xs + residual)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 3,
|
||||
groups: dim,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("pos_enc"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = (xs + xs.apply(&conv)?)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
|
||||
let proj = linear(dim, dim, vb.pp("proj"))?;
|
||||
let head_dim = 32;
|
||||
let num_heads = dim / head_dim;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.clone();
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
let n = h * w;
|
||||
let xs = xs.flatten_from(2)?.transpose(D::Minus1, D::Minus2)?;
|
||||
let qkv = xs
|
||||
.apply(&qkv)?
|
||||
.reshape((b, n, 3, num_heads, head_dim))?
|
||||
.permute((2, 0, 3, 1, 4))?;
|
||||
|
||||
let q = qkv.get(0)?;
|
||||
let k = qkv.get(1)?;
|
||||
let v = qkv.get(2)?;
|
||||
|
||||
let q = (q * scale)?;
|
||||
|
||||
let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?;
|
||||
let att = softmax(&att, D::Minus1)?;
|
||||
let xs = att.matmul(&v)?;
|
||||
|
||||
let xs = xs.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
let xs = xs.apply(&proj)?;
|
||||
let xs = xs.transpose(D::Minus1, D::Minus2)?.reshape((b, c, h, w))?;
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn attention_block(dim: usize, exp_ratio: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let gamma1 = vb.get((dim, 1, 1), "layer_scale_1.gamma")?;
|
||||
let gamma2 = vb.get((dim, 1, 1), "layer_scale_2.gamma")?;
|
||||
let norm = batch_norm(dim, 1e-5, vb.pp("norm"))?;
|
||||
let token_mixer = attention(dim, vb.pp("token_mixer"))?;
|
||||
let mlp = conv_mlp(dim, exp_ratio, vb.pp("mlp"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.clone();
|
||||
let xs = (&xs
|
||||
+ &xs
|
||||
.apply_t(&norm, false)?
|
||||
.apply(&token_mixer)?
|
||||
.broadcast_mul(&gamma1.reshape((1, (), 1, 1))?)?)?;
|
||||
|
||||
let xs = (&xs
|
||||
+ &xs
|
||||
.apply(&mlp)?
|
||||
.broadcast_mul(&gamma2.reshape((1, (), 1, 1))?)?)?;
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let nblocks = cfg.blocks[idx];
|
||||
let mut blocks = Vec::with_capacity(nblocks);
|
||||
|
||||
let dim = cfg.in_channels << idx;
|
||||
let downsample = fastvit_patch_embed(dim / 2, dim, cfg.lkc_use_act, vb.pp("downsample"));
|
||||
for block_idx in 0..nblocks {
|
||||
let block = if cfg.attn && idx == 3 {
|
||||
attention_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
|
||||
} else {
|
||||
repmixer_block(dim, cfg.exp_ratio, vb.pp(format!("blocks.{block_idx}")))?
|
||||
};
|
||||
blocks.push(block);
|
||||
}
|
||||
let pos_emb = positional_encoding(dim, vb.pp("pos_emb"));
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
if let Ok(ds) = &downsample {
|
||||
xs = xs.apply(ds)?;
|
||||
}
|
||||
if let Ok(pos) = &pos_emb {
|
||||
xs = xs.apply(pos)?;
|
||||
}
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_patch_embed(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
use_act: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
|
||||
let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
|
||||
let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = (xs.apply(&lk)? + xs.apply(&sk)?)?;
|
||||
if let Ok(f) = &se {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
if use_act {
|
||||
xs = xs.gelu_erf()?;
|
||||
};
|
||||
let xs = xs.apply(&mb)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn fastvit_stem(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let mb0 = mobileone_block(in_channels, out_channels, 3, 2, 0, true, vb.pp(0))?;
|
||||
let mb1 = mobileone_block(out_channels, out_channels, 3, 2, 1, true, vb.pp(1))?;
|
||||
let mb2 = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp(2))?;
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&mb0)?.apply(&mb1)?.apply(&mb2)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a fastvit model for a given configuration.
|
||||
fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let linear = linear(cfg.in_channels * 16, nclasses, vb.pp("head.fc"))?;
|
||||
Some(linear)
|
||||
}
|
||||
};
|
||||
|
||||
let stem = fastvit_stem(3, cfg.in_channels, vb.pp("stem"))?;
|
||||
let final_conv = mobileone_block(
|
||||
cfg.in_channels * 8,
|
||||
cfg.in_channels * 16,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
true,
|
||||
vb.pp("final_conv"),
|
||||
)?;
|
||||
|
||||
let vb = vb.pp("stages");
|
||||
let stage1 = fastvit_stage(cfg, 0, vb.pp(0))?;
|
||||
let stage2 = fastvit_stage(cfg, 1, vb.pp(1))?;
|
||||
let stage3 = fastvit_stage(cfg, 2, vb.pp(2))?;
|
||||
let stage4 = fastvit_stage(cfg, 3, vb.pp(3))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&stem)?
|
||||
.apply(&stage1)?
|
||||
.apply(&stage2)?
|
||||
.apply(&stage3)?
|
||||
.apply(&stage4)?
|
||||
.apply(&final_conv)?;
|
||||
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn fastvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
fastvit_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn fastvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
fastvit_model(cfg, None, vb)
|
||||
}
|
449
candle-transformers/src/models/gemma2.rs
Normal file
449
candle-transformers/src/models/gemma2.rs
Normal file
@ -0,0 +1,449 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};
|
||||
|
||||
fn default_max_position_embeddings() -> usize {
|
||||
4096
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub attention_bias: bool,
|
||||
pub head_dim: usize,
|
||||
pub hidden_activation: Activation,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f64,
|
||||
pub vocab_size: usize,
|
||||
pub final_logit_softcapping: Option<f64>,
|
||||
pub attn_logit_softcapping: Option<f64>,
|
||||
pub query_pre_attn_scalar: usize,
|
||||
// TODO: Handle the sliding window in the attention mask.
|
||||
pub sliding_window: Option<usize>,
|
||||
|
||||
#[serde(default = "default_max_position_embeddings")]
|
||||
pub max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
weight: Tensor,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let weight = vb.get(dim, "weight")?;
|
||||
Ok(Self { weight, eps })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = x.dim(D::Minus1)?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&(&self.weight + 1.0)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.head_dim;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: candle_nn::Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
|
||||
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_activation,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
attn_logit_softcapping: Option<f64>,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = cfg.head_dim;
|
||||
let bias = cfg.attention_bias;
|
||||
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
|
||||
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
attn_logit_softcapping: cfg.attn_logit_softcapping,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_output = if self.use_flash_attn {
|
||||
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||
let q = query_states.transpose(1, 2)?;
|
||||
let k = key_states.transpose(1, 2)?;
|
||||
let v = value_states.transpose(1, 2)?;
|
||||
let scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
|
||||
} else {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match self.attn_logit_softcapping {
|
||||
None => attn_weights,
|
||||
Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
|
||||
};
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, ()))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
mlp: MLP,
|
||||
input_layernorm: RmsNorm,
|
||||
pre_feedforward_layernorm: RmsNorm,
|
||||
post_feedforward_layernorm: RmsNorm,
|
||||
post_attention_layernorm: RmsNorm,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let pre_feedforward_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("pre_feedforward_layernorm"),
|
||||
)?;
|
||||
let post_feedforward_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_feedforward_layernorm"),
|
||||
)?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
pre_feedforward_layernorm,
|
||||
post_feedforward_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.pre_feedforward_layernorm)?;
|
||||
let xs = xs.apply(&self.mlp)?;
|
||||
let xs = xs.apply(&self.post_feedforward_layernorm)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
final_logit_softcapping: Option<f64>,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
hidden_size: usize,
|
||||
sliding_window: Option<usize>,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer =
|
||||
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
final_logit_softcapping: cfg.final_logit_softcapping,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
hidden_size: cfg.hidden_size,
|
||||
sliding_window: cfg.sliding_window,
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = match self.sliding_window {
|
||||
None => (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect(),
|
||||
Some(sliding_window) => (0..tgt_len)
|
||||
.flat_map(|i| {
|
||||
(0..tgt_len).map(move |j| {
|
||||
if i < j || j + sliding_window < i {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let xs = self.embed_tokens.forward(input_ids)?;
|
||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
let logits = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)?;
|
||||
let logits = match self.final_logit_softcapping {
|
||||
None => logits,
|
||||
Some(sc) => ((logits / sc)?.tanh()? * sc)?,
|
||||
};
|
||||
|
||||
Ok(logits)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
595
candle-transformers/src/models/glm4.rs
Normal file
595
candle-transformers/src/models/glm4.rs
Normal file
@ -0,0 +1,595 @@
|
||||
use crate::models::with_tracing::{linear_b as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub num_layers: usize,
|
||||
pub padded_vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub ffn_hidden_size: usize,
|
||||
pub kv_channels: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub seq_length: usize,
|
||||
pub layernorm_epsilon: f64,
|
||||
pub rmsnorm: bool,
|
||||
pub apply_residual_connection_post_layernorm: bool,
|
||||
pub post_layer_norm: bool,
|
||||
pub add_bias_linear: bool,
|
||||
pub add_qkv_bias: bool,
|
||||
pub bias_dropout_fusion: bool,
|
||||
pub multi_query_attention: bool,
|
||||
pub multi_query_group_num: usize,
|
||||
pub apply_query_key_layer_scaling: bool,
|
||||
pub attention_softmax_in_fp32: bool,
|
||||
pub fp32_residual_connection: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn glm4() -> Self {
|
||||
Self {
|
||||
num_layers: 40,
|
||||
padded_vocab_size: 151552,
|
||||
hidden_size: 4096,
|
||||
ffn_hidden_size: 13696,
|
||||
kv_channels: 128,
|
||||
num_attention_heads: 32,
|
||||
seq_length: 8192,
|
||||
layernorm_epsilon: 1e-5,
|
||||
rmsnorm: true,
|
||||
apply_residual_connection_post_layernorm: false,
|
||||
post_layer_norm: true,
|
||||
add_bias_linear: false,
|
||||
add_qkv_bias: true,
|
||||
bias_dropout_fusion: true,
|
||||
multi_query_attention: true,
|
||||
multi_query_group_num: 2,
|
||||
apply_query_key_layer_scaling: true,
|
||||
attention_softmax_in_fp32: true,
|
||||
fp32_residual_connection: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
cache: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let rotary_dim = cfg.kv_channels;
|
||||
let n_elem = rotary_dim / 2;
|
||||
let inv_freq: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((cfg.seq_length, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
|
||||
Ok(Self { cache })
|
||||
}
|
||||
|
||||
fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (seqlen, _b, np, _hn) = xs.dims4()?;
|
||||
let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
|
||||
let rot_dim = cache.dim(D::Minus2)? * 2;
|
||||
let (xs, xs_pass) = (
|
||||
xs.narrow(D::Minus1, 0, rot_dim)?,
|
||||
xs.narrow(D::Minus1, rot_dim, rot_dim)?,
|
||||
);
|
||||
let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
|
||||
let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
|
||||
let (xshaped0, xshaped1) = (
|
||||
xshaped.i((.., .., .., .., 0))?,
|
||||
xshaped.i((.., .., .., .., 1))?,
|
||||
);
|
||||
let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
|
||||
let xs_out = Tensor::stack(
|
||||
&[
|
||||
(xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
|
||||
(xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let xs_out = xs_out.flatten_from(3)?;
|
||||
Tensor::cat(&[xs_out, xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CoreAttention {
|
||||
coeff: Option<f64>,
|
||||
norm_factor: f64,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32, dtype: DType) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true.to_dtype(dtype)?, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
impl CoreAttention {
|
||||
fn new(layer_number: usize, cfg: &Config, dtype: DType) -> Result<Self> {
|
||||
let norm_factor = (cfg.kv_channels as f64).sqrt();
|
||||
let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
|
||||
let coeff = f64::max(1.0, layer_number as f64);
|
||||
(norm_factor * coeff, Some(coeff))
|
||||
} else {
|
||||
(norm_factor, None)
|
||||
};
|
||||
Ok(Self {
|
||||
coeff,
|
||||
norm_factor,
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
query_layer: &Tensor,
|
||||
key_layer: &Tensor,
|
||||
value_layer: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let output_size = (
|
||||
query_layer.dim(1)?, // b
|
||||
query_layer.dim(2)?, // np
|
||||
query_layer.dim(0)?, // sq
|
||||
key_layer.dim(0)?, // sk
|
||||
);
|
||||
let query_layer =
|
||||
query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
|
||||
let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
|
||||
let matmul_result = Tensor::matmul(
|
||||
&query_layer.transpose(0, 1)?.contiguous()?,
|
||||
&key_layer.transpose(0, 1)?.transpose(1, 2)?.contiguous()?,
|
||||
)?;
|
||||
let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
|
||||
let matmul_result = match self.coeff {
|
||||
None => matmul_result,
|
||||
Some(coeff) => (matmul_result * coeff)?,
|
||||
};
|
||||
let attention_scores = match attention_mask {
|
||||
Some(mask) => masked_fill(
|
||||
&matmul_result,
|
||||
&mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
|
||||
f32::NEG_INFINITY,
|
||||
self.dtype,
|
||||
)?,
|
||||
None => matmul_result,
|
||||
};
|
||||
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||
|
||||
let output_size = (
|
||||
value_layer.dim(1)?,
|
||||
value_layer.dim(2)?,
|
||||
query_layer.dim(0)?,
|
||||
value_layer.dim(3)?,
|
||||
);
|
||||
let value_layer =
|
||||
value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
|
||||
let attention_probs =
|
||||
attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
|
||||
let context_layer = Tensor::matmul(
|
||||
&attention_probs.contiguous()?,
|
||||
&value_layer.transpose(0, 1)?.contiguous()?,
|
||||
)?;
|
||||
let context_layer = context_layer.reshape(output_size)?;
|
||||
let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
|
||||
context_layer.flatten_from(D::Minus2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SelfAttention {
|
||||
query_key_value: Linear,
|
||||
core_attention: CoreAttention,
|
||||
dense: Linear,
|
||||
multi_query_attention: bool,
|
||||
num_attention_heads_per_partition: usize,
|
||||
num_multi_query_groups_per_partition: usize,
|
||||
hidden_size_per_attention_head: usize,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let projection_size = cfg.kv_channels * cfg.num_attention_heads;
|
||||
let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
|
||||
let qkv_hidden_size = if cfg.multi_query_attention {
|
||||
projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
|
||||
} else {
|
||||
3 * projection_size
|
||||
};
|
||||
let query_key_value = linear(
|
||||
cfg.hidden_size,
|
||||
qkv_hidden_size,
|
||||
cfg.add_bias_linear || cfg.add_qkv_bias,
|
||||
vb.pp("query_key_value"),
|
||||
)?;
|
||||
let core_attention = CoreAttention::new(layer_number, cfg, vb.dtype())?;
|
||||
let dense = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
query_key_value,
|
||||
core_attention,
|
||||
dense,
|
||||
multi_query_attention: cfg.multi_query_attention,
|
||||
num_attention_heads_per_partition: cfg.num_attention_heads,
|
||||
num_multi_query_groups_per_partition: cfg.multi_query_group_num,
|
||||
hidden_size_per_attention_head: cfg.kv_channels,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
rotary_emb: &RotaryEmbedding,
|
||||
) -> Result<Tensor> {
|
||||
let mixed_x_layer = xs.apply(&self.query_key_value)?;
|
||||
if !self.multi_query_attention {
|
||||
candle::bail!("only multi_query_attention=true is supported")
|
||||
}
|
||||
let hpa = self.hidden_size_per_attention_head;
|
||||
let query_layer =
|
||||
mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
|
||||
let key_layer = mixed_x_layer.narrow(
|
||||
D::Minus1,
|
||||
self.num_attention_heads_per_partition * hpa,
|
||||
self.num_multi_query_groups_per_partition * hpa,
|
||||
)?;
|
||||
let value_layer = mixed_x_layer.narrow(
|
||||
D::Minus1,
|
||||
self.num_attention_heads_per_partition * hpa
|
||||
+ self.num_multi_query_groups_per_partition * hpa,
|
||||
self.num_multi_query_groups_per_partition * hpa,
|
||||
)?;
|
||||
let query_layer = query_layer.reshape((
|
||||
query_layer.dim(0)?,
|
||||
query_layer.dim(1)?,
|
||||
self.num_attention_heads_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
let key_layer = key_layer.reshape((
|
||||
key_layer.dim(0)?,
|
||||
key_layer.dim(1)?,
|
||||
self.num_multi_query_groups_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
let value_layer = value_layer.reshape((
|
||||
value_layer.dim(0)?,
|
||||
value_layer.dim(1)?,
|
||||
self.num_multi_query_groups_per_partition,
|
||||
hpa,
|
||||
))?;
|
||||
|
||||
// Rotary embeddings.
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(0)?,
|
||||
};
|
||||
let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
|
||||
let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
|
||||
|
||||
// KV cache.
|
||||
let (key_layer, value_layer) = match &self.kv_cache {
|
||||
None => (key_layer, value_layer),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
|
||||
let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
|
||||
|
||||
// Repeat KV.
|
||||
let ratio =
|
||||
self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
|
||||
let key_layer = {
|
||||
let (d0, d1, d2, d3) = key_layer.dims4()?;
|
||||
key_layer
|
||||
.unsqueeze(D::Minus2)?
|
||||
.expand((d0, d1, d2, ratio, d3))?
|
||||
.reshape((
|
||||
d0,
|
||||
d1,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
))?
|
||||
};
|
||||
let value_layer = {
|
||||
let (d0, d1, d2, d3) = value_layer.dims4()?;
|
||||
value_layer
|
||||
.unsqueeze(D::Minus2)?
|
||||
.expand((d0, d1, d2, ratio, d3))?
|
||||
.reshape((
|
||||
d0,
|
||||
d1,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
))?
|
||||
};
|
||||
|
||||
let context_layer =
|
||||
self.core_attention
|
||||
.forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
|
||||
let output = context_layer.apply(&self.dense)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct MLP {
|
||||
dense_h_to_4h: Linear,
|
||||
dense_4h_to_h: Linear,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dense_h_to_4h = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.ffn_hidden_size * 2,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense_h_to_4h"),
|
||||
)?;
|
||||
let dense_4h_to_h = linear(
|
||||
cfg.ffn_hidden_size,
|
||||
cfg.hidden_size,
|
||||
cfg.add_bias_linear,
|
||||
vb.pp("dense_4h_to_h"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
dense_4h_to_h,
|
||||
dense_h_to_4h,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.dense_h_to_4h)?
|
||||
.apply(&candle_nn::Activation::Swiglu)?
|
||||
.apply(&self.dense_4h_to_h)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
input_layernorm: candle_nn::LayerNorm,
|
||||
self_attention: SelfAttention,
|
||||
post_attention_layernorm: candle_nn::LayerNorm,
|
||||
mlp: MLP,
|
||||
apply_residual_connection_post_layernorm: bool,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let input_layernorm = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("input_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("input_layernorm"),
|
||||
)?
|
||||
};
|
||||
let post_attention_layernorm = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?
|
||||
};
|
||||
let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
Ok(Self {
|
||||
input_layernorm,
|
||||
self_attention,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
self.self_attention.reset_kv_cache()
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
rotary_emb: &RotaryEmbedding,
|
||||
) -> Result<Tensor> {
|
||||
let layernorm_output = xs.apply(&self.input_layernorm)?;
|
||||
let attention_output =
|
||||
self.self_attention
|
||||
.forward(&layernorm_output, attention_mask, rotary_emb)?;
|
||||
let residual = if self.apply_residual_connection_post_layernorm {
|
||||
&layernorm_output
|
||||
} else {
|
||||
xs
|
||||
};
|
||||
let layernorm_input = (residual + attention_output)?;
|
||||
let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
|
||||
let mlp_output = layernorm_output.apply(&self.mlp)?;
|
||||
let residual = if self.apply_residual_connection_post_layernorm {
|
||||
&layernorm_output
|
||||
} else {
|
||||
&layernorm_input
|
||||
};
|
||||
mlp_output + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Transformer {
|
||||
layers: Vec<Block>,
|
||||
final_layernorm: Option<candle_nn::LayerNorm>,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
}
|
||||
|
||||
impl Transformer {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_l = vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_layers);
|
||||
for layer_index in 0..cfg.num_layers {
|
||||
let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
|
||||
layers.push(block)
|
||||
}
|
||||
let final_layernorm = if cfg.post_layer_norm {
|
||||
let ln = if cfg.rmsnorm {
|
||||
candle_nn::rms_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("final_layernorm"),
|
||||
)?
|
||||
.into_inner()
|
||||
} else {
|
||||
candle_nn::layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layernorm_epsilon,
|
||||
vb.pp("final_layernorm"),
|
||||
)?
|
||||
};
|
||||
Some(ln)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
|
||||
Ok(Self {
|
||||
layers,
|
||||
final_layernorm,
|
||||
rotary_emb,
|
||||
})
|
||||
}
|
||||
|
||||
fn reset_kv_cache(&mut self) {
|
||||
for block in self.layers.iter_mut() {
|
||||
block.reset_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for block in self.layers.iter_mut() {
|
||||
xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
|
||||
}
|
||||
match self.final_layernorm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(ln) => xs.apply(ln),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Embedding {
|
||||
word_embeddings: candle_nn::Embedding,
|
||||
fp32_residual_connection: bool,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let word_embeddings = candle_nn::embedding(
|
||||
cfg.padded_vocab_size,
|
||||
cfg.hidden_size,
|
||||
vb.pp("word_embeddings"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
fp32_residual_connection: cfg.fp32_residual_connection,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Embedding {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
|
||||
if self.fp32_residual_connection {
|
||||
xs.to_dtype(candle::DType::F32)
|
||||
} else {
|
||||
xs.contiguous()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
embedding: Embedding,
|
||||
encoder: Transformer,
|
||||
output_layer: Linear,
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb = vb.pp("transformer");
|
||||
let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
|
||||
let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
|
||||
let output_layer = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.padded_vocab_size,
|
||||
false,
|
||||
vb.pp("output_layer"),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
embedding,
|
||||
encoder,
|
||||
output_layer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.encoder.reset_kv_cache()
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let input_embeds = xs.apply(&self.embedding)?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
|
||||
let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
|
||||
Ok(lm_logits)
|
||||
}
|
||||
}
|
458
candle-transformers/src/models/granite.rs
Normal file
458
candle-transformers/src/models/granite.rs
Normal file
@ -0,0 +1,458 @@
|
||||
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use std::{collections::HashMap, f32::consts::PI};
|
||||
|
||||
pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, Default)]
|
||||
pub enum GraniteRopeType {
|
||||
#[serde(rename = "granite")]
|
||||
Granite,
|
||||
#[default]
|
||||
#[serde(rename = "default")]
|
||||
Default,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, Default)]
|
||||
pub struct GraniteRopeConfig {
|
||||
pub factor: f32,
|
||||
pub low_freq_factor: f32,
|
||||
pub high_freq_factor: f32,
|
||||
pub original_max_position_embeddings: usize,
|
||||
pub rope_type: GraniteRopeType,
|
||||
}
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum GraniteEosToks {
|
||||
Single(u32),
|
||||
Multiple(Vec<u32>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct GraniteConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<GraniteEosToks>,
|
||||
pub rope_scaling: Option<GraniteRopeConfig>,
|
||||
pub max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
impl GraniteConfig {
|
||||
pub fn num_key_value_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
impl GraniteConfig {
|
||||
pub fn into_config(self, use_flash_attn: bool) -> Config {
|
||||
Config {
|
||||
hidden_size: self.hidden_size,
|
||||
intermediate_size: self.intermediate_size,
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads(),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
bos_token_id: self.bos_token_id,
|
||||
eos_token_id: self.eos_token_id,
|
||||
rope_scaling: self.rope_scaling,
|
||||
max_position_embeddings: self.max_position_embeddings,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<GraniteEosToks>,
|
||||
pub rope_scaling: Option<GraniteRopeConfig>,
|
||||
pub max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
masks: HashMap<usize, Tensor>,
|
||||
pub use_kv_cache: bool,
|
||||
kvs: Vec<Option<(Tensor, Tensor)>>,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
|
||||
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
(0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
|
||||
// precompute freqs_cis
|
||||
let theta = match &config.rope_scaling {
|
||||
None
|
||||
| Some(GraniteRopeConfig {
|
||||
rope_type: GraniteRopeType::Default,
|
||||
..
|
||||
}) => calculate_default_inv_freq(config),
|
||||
Some(rope_scaling) => {
|
||||
let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
|
||||
/ rope_scaling.low_freq_factor;
|
||||
let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
|
||||
/ rope_scaling.high_freq_factor;
|
||||
|
||||
calculate_default_inv_freq(config)
|
||||
.into_iter()
|
||||
.map(|freq| {
|
||||
let wavelen = 2. * PI / freq;
|
||||
if wavelen < high_freq_wavelen {
|
||||
freq
|
||||
} else if wavelen > low_freq_wavelen {
|
||||
freq / rope_scaling.factor
|
||||
} else {
|
||||
let smooth = (rope_scaling.original_max_position_embeddings as f32
|
||||
/ wavelen
|
||||
- rope_scaling.low_freq_factor)
|
||||
/ (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
|
||||
(1. - smooth) * freq / rope_scaling.factor + smooth * freq
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
};
|
||||
|
||||
let theta = Tensor::new(theta, device)?;
|
||||
|
||||
let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((config.max_position_embeddings, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self {
|
||||
masks: HashMap::new(),
|
||||
use_kv_cache,
|
||||
kvs: vec![None; config.num_hidden_layers],
|
||||
device: device.clone(),
|
||||
cos,
|
||||
sin,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalSelfAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_attention_heads: usize,
|
||||
num_key_value_heads: usize,
|
||||
head_dim: usize,
|
||||
use_flash_attn: bool,
|
||||
span: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
max_position_embeddings: usize,
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
|
||||
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
candle_nn::rotary_emb::rope(x, &cos, &sin)
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||
let q = self.q_proj.forward(x)?;
|
||||
let k = self.k_proj.forward(x)?;
|
||||
let v = self.v_proj.forward(x)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
|
||||
|
||||
if cache.use_kv_cache {
|
||||
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
|
||||
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
|
||||
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
|
||||
let k_seq_len = k.dims()[1];
|
||||
if k_seq_len > self.max_position_embeddings {
|
||||
k = k
|
||||
.narrow(
|
||||
D::Minus1,
|
||||
k_seq_len - self.max_position_embeddings,
|
||||
self.max_position_embeddings,
|
||||
)?
|
||||
.contiguous()?
|
||||
}
|
||||
let v_seq_len = v.dims()[1];
|
||||
if v_seq_len > 2 * self.max_position_embeddings {
|
||||
v = v
|
||||
.narrow(
|
||||
D::Minus1,
|
||||
v_seq_len - self.max_position_embeddings,
|
||||
self.max_position_embeddings,
|
||||
)?
|
||||
.contiguous()?
|
||||
}
|
||||
}
|
||||
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
|
||||
}
|
||||
|
||||
let k = self.repeat_kv(k)?;
|
||||
let v = self.repeat_kv(v)?;
|
||||
|
||||
let y = if self.use_flash_attn {
|
||||
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||
let q = q.transpose(1, 2)?;
|
||||
let k = k.transpose(1, 2)?;
|
||||
let v = v.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
|
||||
} else {
|
||||
let in_dtype = q.dtype();
|
||||
let q = q.to_dtype(DType::F32)?;
|
||||
let k = k.to_dtype(DType::F32)?;
|
||||
let v = v.to_dtype(DType::F32)?;
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let att = if seq_len == 1 {
|
||||
att
|
||||
} else {
|
||||
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||
masked_fill(&att, &mask, f32::NEG_INFINITY)?
|
||||
};
|
||||
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
|
||||
};
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let size_in = cfg.hidden_size;
|
||||
let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
|
||||
let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
|
||||
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
|
||||
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_attention_heads: cfg.num_attention_heads,
|
||||
num_key_value_heads: cfg.num_key_value_heads,
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
use_flash_attn: cfg.use_flash_attn,
|
||||
span,
|
||||
span_rot,
|
||||
max_position_embeddings: cfg.max_position_embeddings,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
c_fc1: Linear,
|
||||
c_fc2: Linear,
|
||||
c_proj: Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||
let h_size = cfg.hidden_size;
|
||||
let i_size = cfg.intermediate_size;
|
||||
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
|
||||
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
|
||||
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Block {
|
||||
rms_1: RmsNorm,
|
||||
attn: CausalSelfAttention,
|
||||
rms_2: RmsNorm,
|
||||
mlp: Mlp,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn forward(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
index_pos: usize,
|
||||
block_idx: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "block");
|
||||
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
|
||||
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||
let rms_1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let rms_2 = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
rms_1,
|
||||
attn,
|
||||
rms_2,
|
||||
mlp,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Granite {
|
||||
wte: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
ln_f: RmsNorm,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl Granite {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
logits.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
wte,
|
||||
blocks,
|
||||
ln_f,
|
||||
lm_head,
|
||||
})
|
||||
}
|
||||
}
|
@ -344,7 +344,7 @@ impl BertEncoder {
|
||||
candle::bail!("only alibi is supported as a position-embedding-type")
|
||||
}
|
||||
let layers = (0..cfg.num_hidden_layers)
|
||||
.map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
|
||||
.map(|index| BertLayer::new(vb.pp(format!("layer.{index}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
|
||||
|
@ -507,7 +507,7 @@ impl Llama {
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cfg).unwrap())
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
|
@ -354,7 +354,7 @@ impl Llama {
|
||||
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
|
||||
let blocks: Vec<_> = (0..cfg.n_layers)
|
||||
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
|
||||
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
|
||||
.collect();
|
||||
Ok(Self {
|
||||
wte,
|
||||
|
670
candle-transformers/src/models/mimi/conv.rs
Normal file
670
candle-transformers/src/models/mimi/conv.rs
Normal file
@ -0,0 +1,670 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{Module, Result, StreamTensor, StreamingModule, Tensor, D};
|
||||
use candle_nn::{Conv1d, VarBuilder};
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Norm {
|
||||
WeightNorm,
|
||||
SpectralNorm,
|
||||
TimeGroupNorm,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum PadMode {
|
||||
Constant,
|
||||
Reflect,
|
||||
Replicate,
|
||||
}
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
||||
fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
bias: bool,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight = if vb.contains_tensor("weight") {
|
||||
vb.get((out_c, in_c, kernel_size), "weight")?
|
||||
} else {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
|
||||
};
|
||||
let bias = if bias {
|
||||
Some(vb.get(out_c, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Conv1d::new(weight, bias, config))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NormConv1d {
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl NormConv1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
bias: bool,
|
||||
cfg: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let conv = match norm {
|
||||
None | Some(Norm::TimeGroupNorm) => {
|
||||
if bias {
|
||||
candle_nn::conv1d(in_c, out_c, k_size, cfg, vb.pp("conv"))?
|
||||
} else {
|
||||
candle_nn::conv1d_no_bias(in_c, out_c, k_size, cfg, vb.pp("conv"))?
|
||||
}
|
||||
}
|
||||
Some(Norm::WeightNorm) => {
|
||||
conv1d_weight_norm(in_c, out_c, k_size, bias, cfg, vb.pp("conv"))?
|
||||
}
|
||||
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
|
||||
};
|
||||
let norm = match norm {
|
||||
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
|
||||
Some(Norm::TimeGroupNorm) => {
|
||||
if causal {
|
||||
candle::bail!("GroupNorm doesn't support causal evaluation.")
|
||||
}
|
||||
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(norm)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
conv,
|
||||
norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "norm-conv1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for NormConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = xs.apply(&self.conv)?;
|
||||
match self.norm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NormConvTranspose1d {
|
||||
ws: Tensor,
|
||||
bs: Option<Tensor>,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl NormConvTranspose1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
bias: bool,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("conv");
|
||||
let bs = if bias {
|
||||
Some(vb.get(out_c, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ws = match norm {
|
||||
None | Some(Norm::TimeGroupNorm) => vb.get((in_c, out_c / groups, k_size), "weight")?,
|
||||
Some(Norm::WeightNorm) => {
|
||||
if vb.contains_tensor("weight") {
|
||||
vb.get((in_c, out_c, k_size), "weight")?
|
||||
} else {
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
|
||||
}
|
||||
}
|
||||
Some(Norm::SpectralNorm) => candle::bail!("SpectralNorm is not supported yet."),
|
||||
};
|
||||
let (ws, groups) = if groups == out_c && in_c == out_c {
|
||||
let eye = Tensor::eye(out_c, ws.dtype(), ws.device())?;
|
||||
let ws = ws
|
||||
.repeat((1, out_c, 1))?
|
||||
.mul(&eye.unsqueeze(2)?.repeat((1, 1, k_size))?)?;
|
||||
(ws, 1)
|
||||
} else {
|
||||
(ws, groups)
|
||||
};
|
||||
let norm = match norm {
|
||||
None | Some(Norm::WeightNorm) | Some(Norm::SpectralNorm) => None,
|
||||
Some(Norm::TimeGroupNorm) => {
|
||||
if causal {
|
||||
candle::bail!("GroupNorm doesn't support causal evaluation.")
|
||||
}
|
||||
let norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(norm)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
ws,
|
||||
bs,
|
||||
k_size,
|
||||
stride,
|
||||
groups,
|
||||
norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "norm-conv-tr1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for NormConvTranspose1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
// conv-transpose1d seems to be broken on metal after enough iterations. Causing
|
||||
// the following error:
|
||||
// _status < MTLCommandBufferStatusCommitted >
|
||||
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
|
||||
// This is now fixed in candle.
|
||||
let xs = Tensor::conv_transpose1d(xs, &self.ws, 0, 0, self.stride, 1, self.groups)?;
|
||||
let xs = match &self.bs {
|
||||
None => xs,
|
||||
Some(bias) => {
|
||||
let b = bias.dims1()?;
|
||||
let bias = bias.reshape((1, b, 1))?;
|
||||
xs.broadcast_add(&bias)?
|
||||
}
|
||||
};
|
||||
match self.norm.as_ref() {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_extra_padding_for_conv1d(
|
||||
xs: &Tensor,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
padding_total: usize,
|
||||
) -> Result<usize> {
|
||||
let len = xs.dim(D::Minus1)?;
|
||||
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
|
||||
let ideal_len =
|
||||
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
|
||||
Ok(ideal_len.saturating_sub(len))
|
||||
}
|
||||
|
||||
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
|
||||
match mode {
|
||||
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
|
||||
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
|
||||
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
|
||||
}
|
||||
}
|
||||
|
||||
fn unpad1d(xs: &Tensor, unpad_l: usize, unpad_r: usize) -> Result<Tensor> {
|
||||
let len = xs.dim(D::Minus1)?;
|
||||
if len < unpad_l + unpad_r {
|
||||
candle::bail!("unpad1d: tensor len {len} is too low, {unpad_l} + {unpad_r}")
|
||||
}
|
||||
xs.narrow(D::Minus1, unpad_l, len - (unpad_l + unpad_r))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamableConv1d {
|
||||
conv: NormConv1d,
|
||||
causal: bool,
|
||||
pad_mode: PadMode,
|
||||
state_prev_xs: StreamTensor,
|
||||
left_pad_applied: bool,
|
||||
kernel_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamableConv1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
pad_mode: PadMode,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let cfg = candle_nn::Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
};
|
||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||
if k_size < stride {
|
||||
candle::bail!("kernel-size {k_size} is smaller than stride {stride}")
|
||||
}
|
||||
Ok(Self {
|
||||
conv,
|
||||
causal,
|
||||
pad_mode,
|
||||
state_prev_xs: StreamTensor::empty(),
|
||||
left_pad_applied: false,
|
||||
kernel_size: k_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "streamable-conv1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for StreamableConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_b, _t, _c) = xs.dims3()?;
|
||||
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
|
||||
let conv_cfg = self.conv.conv.config();
|
||||
// Effective kernel size with dilations.
|
||||
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
|
||||
let padding_total = k_size - conv_cfg.stride;
|
||||
let extra_padding =
|
||||
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
|
||||
let xs = if self.causal {
|
||||
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
|
||||
} else {
|
||||
let padding_right = padding_total / 2;
|
||||
let padding_left = padding_total - padding_right;
|
||||
pad1d(
|
||||
xs,
|
||||
padding_left,
|
||||
padding_right + extra_padding,
|
||||
self.pad_mode,
|
||||
)?
|
||||
};
|
||||
xs.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for StreamableConv1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.state_prev_xs.reset();
|
||||
self.left_pad_applied = false;
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match xs.as_option() {
|
||||
None => return Ok(().into()),
|
||||
Some(xs) => xs.clone(),
|
||||
};
|
||||
let xs = if self.left_pad_applied {
|
||||
xs
|
||||
} else {
|
||||
self.left_pad_applied = true;
|
||||
let k_size = self.conv.conv.weight().dim(D::Minus1)?;
|
||||
let conv_cfg = self.conv.conv.config();
|
||||
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
|
||||
let padding_total = k_size - conv_cfg.stride;
|
||||
pad1d(&xs, padding_total, 0, self.pad_mode)?
|
||||
};
|
||||
let cfg = self.conv.conv.config();
|
||||
let stride = cfg.stride;
|
||||
let dilation = cfg.dilation;
|
||||
let kernel = (self.kernel_size - 1) * dilation + 1;
|
||||
let xs = StreamTensor::cat2(&self.state_prev_xs, &xs.into(), D::Minus1)?;
|
||||
let seq_len = xs.seq_len(D::Minus1)?;
|
||||
let num_frames = (seq_len + stride).saturating_sub(kernel) / stride;
|
||||
if num_frames > 0 {
|
||||
let offset = num_frames * stride;
|
||||
self.state_prev_xs = xs.narrow(D::Minus1, offset, seq_len - offset)?;
|
||||
let in_l = (num_frames - 1) * stride + kernel;
|
||||
let xs = xs.narrow(D::Minus1, 0, in_l)?;
|
||||
// We apply the underlying convtr directly rather than through forward so as
|
||||
// not to apply any padding here.
|
||||
xs.apply(&self.conv.conv)
|
||||
} else {
|
||||
self.state_prev_xs = xs;
|
||||
Ok(StreamTensor::empty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamableConvTranspose1d {
|
||||
convtr: NormConvTranspose1d,
|
||||
causal: bool,
|
||||
state_prev_ys: StreamTensor,
|
||||
kernel_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamableConvTranspose1d {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
groups: usize,
|
||||
bias: bool,
|
||||
causal: bool,
|
||||
norm: Option<Norm>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let convtr =
|
||||
NormConvTranspose1d::new(in_c, out_c, k_size, causal, norm, bias, stride, groups, vb)?;
|
||||
Ok(Self {
|
||||
convtr,
|
||||
causal,
|
||||
kernel_size: k_size,
|
||||
state_prev_ys: StreamTensor::empty(),
|
||||
span: tracing::span!(tracing::Level::TRACE, "streamable-conv-tr1d"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for StreamableConvTranspose1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let k_size = self.convtr.k_size;
|
||||
let stride = self.convtr.stride;
|
||||
let padding_total = k_size.saturating_sub(stride);
|
||||
let xs = xs.apply(&self.convtr)?;
|
||||
if self.causal {
|
||||
// This corresponds to trim_right_ratio = 1.
|
||||
unpad1d(&xs, 0, padding_total)
|
||||
} else {
|
||||
let padding_right = padding_total / 2;
|
||||
let padding_left = padding_total - padding_right;
|
||||
unpad1d(&xs, padding_left, padding_right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for StreamableConvTranspose1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.state_prev_ys.reset()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match xs.as_option() {
|
||||
Some(xs) => xs,
|
||||
None => return Ok(StreamTensor::empty()),
|
||||
};
|
||||
let stride = self.convtr.stride;
|
||||
// We apply the underlying convtr directly rather than through forward so as
|
||||
// not to apply any padding here.
|
||||
let ys = self.convtr.forward(xs)?;
|
||||
let ot = ys.dim(D::Minus1)?;
|
||||
let ys = match self.state_prev_ys.as_option() {
|
||||
None => ys,
|
||||
Some(prev_ys) => {
|
||||
let pt = prev_ys.dim(D::Minus1)?;
|
||||
// Remove the bias as it will be applied multiple times.
|
||||
let prev_ys = match &self.convtr.bs {
|
||||
None => prev_ys.clone(),
|
||||
Some(bias) => {
|
||||
let bias = bias.reshape((1, (), 1))?;
|
||||
prev_ys.broadcast_sub(&bias)?
|
||||
}
|
||||
};
|
||||
let ys1 = (ys.narrow(D::Minus1, 0, pt)? + prev_ys)?;
|
||||
let ys2 = ys.narrow(D::Minus1, pt, ot - pt)?;
|
||||
Tensor::cat(&[ys1, ys2], D::Minus1)?
|
||||
}
|
||||
};
|
||||
let invalid_steps = self.kernel_size - stride;
|
||||
let (ys, prev_ys) = StreamTensor::from(ys).split(D::Minus1, ot - invalid_steps)?;
|
||||
self.state_prev_ys = prev_ys;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConvDownsample1d {
|
||||
conv: StreamableConv1d,
|
||||
}
|
||||
|
||||
impl ConvDownsample1d {
|
||||
pub fn new(
|
||||
stride: usize,
|
||||
dim: usize,
|
||||
causal: bool,
|
||||
learnt: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if !learnt {
|
||||
candle::bail!("only learnt=true is supported")
|
||||
}
|
||||
let conv = StreamableConv1d::new(
|
||||
/* in_c */ dim,
|
||||
/* out_c */ dim,
|
||||
/* k_size_c */ 2 * stride,
|
||||
/* stride */ stride,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1, // channel_wise = false
|
||||
/* bias */ false,
|
||||
/* causal */ causal,
|
||||
/* norm */ None,
|
||||
/* pad_mode */ PadMode::Replicate,
|
||||
vb,
|
||||
)?;
|
||||
Ok(Self { conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvDownsample1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for ConvDownsample1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.conv.reset_state()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
self.conv.step(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConvTrUpsample1d {
|
||||
convtr: StreamableConvTranspose1d,
|
||||
}
|
||||
|
||||
impl ConvTrUpsample1d {
|
||||
pub fn new(
|
||||
stride: usize,
|
||||
dim: usize,
|
||||
causal: bool,
|
||||
learnt: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if !learnt {
|
||||
candle::bail!("only learnt=true is supported")
|
||||
}
|
||||
let convtr = StreamableConvTranspose1d::new(
|
||||
dim,
|
||||
dim,
|
||||
/* k_size */ 2 * stride,
|
||||
/* stride */ stride,
|
||||
/* groups */ dim,
|
||||
/* bias */ false,
|
||||
/* causal */ causal,
|
||||
/* norm */ None,
|
||||
vb,
|
||||
)?;
|
||||
Ok(Self { convtr })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ConvTrUpsample1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.convtr)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for ConvTrUpsample1d {
|
||||
fn reset_state(&mut self) {
|
||||
self.convtr.reset_state()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
self.convtr.step(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use candle::IndexOp;
|
||||
|
||||
fn run_conv1d(
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
dilation: usize,
|
||||
step_size: usize,
|
||||
len: usize,
|
||||
bias: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: We should ensure for the seed to be constant when running these tests.
|
||||
let dev = &candle::Device::Cpu;
|
||||
let vm = candle_nn::VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
|
||||
let conv1d = StreamableConv1d::new(
|
||||
/* in_c */ 2,
|
||||
/* out_c */ 3,
|
||||
/* k_size */ k_size,
|
||||
/* stride */ stride,
|
||||
/* dilation */ dilation,
|
||||
/* groups */ 1,
|
||||
/* bias */ bias,
|
||||
/* causal */ true,
|
||||
/* norm */ None,
|
||||
/* pad_mode */ PadMode::Constant,
|
||||
vb,
|
||||
)?;
|
||||
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
|
||||
let ys = conv1d.forward(&xs)?;
|
||||
let mut conv1d = conv1d;
|
||||
let mut ys_steps = vec![];
|
||||
for idx in 0..len {
|
||||
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
|
||||
let ys = conv1d.step(&xs.into())?;
|
||||
if let Some(ys) = ys.as_option() {
|
||||
ys_steps.push(ys.clone())
|
||||
}
|
||||
}
|
||||
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
|
||||
let diff = (&ys - &ys_steps)?
|
||||
.abs()?
|
||||
.flatten_all()?
|
||||
.max(0)?
|
||||
.to_vec0::<f32>()?;
|
||||
if diff > 1e-5 {
|
||||
println!("{xs}");
|
||||
println!("{ys}");
|
||||
println!("{ys_steps}");
|
||||
candle::bail!("larger diff than expected {diff}")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_conv_tr1d(
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
step_size: usize,
|
||||
len: usize,
|
||||
bias: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: We should ensure for the seed to be constant when running these tests.
|
||||
let dev = &candle::Device::Cpu;
|
||||
let vm = candle_nn::VarMap::new();
|
||||
let vb = VarBuilder::from_varmap(&vm, candle::DType::F32, dev);
|
||||
let conv1d = StreamableConvTranspose1d::new(
|
||||
/* in_c */ 2, /* out_c */ 3, /* k_size */ k_size,
|
||||
/* stride */ stride, /* groups */ 1, /* bias */ bias,
|
||||
/* causal */ true, /* norm */ None, vb,
|
||||
)?;
|
||||
let xs = Tensor::randn(0f32, 1., (1, 2, step_size * len), dev)?;
|
||||
let ys = conv1d.forward(&xs)?;
|
||||
let mut conv1d = conv1d;
|
||||
let mut ys_steps = vec![];
|
||||
for idx in 0..len {
|
||||
let xs = xs.i((.., .., step_size * idx..step_size * (idx + 1)))?;
|
||||
let ys = conv1d.step(&xs.into())?;
|
||||
if let Some(ys) = ys.as_option() {
|
||||
ys_steps.push(ys.clone())
|
||||
}
|
||||
}
|
||||
let ys_steps = Tensor::cat(&ys_steps, D::Minus1)?;
|
||||
let diff = (&ys - &ys_steps)?
|
||||
.abs()?
|
||||
.flatten_all()?
|
||||
.max(0)?
|
||||
.to_vec0::<f32>()?;
|
||||
if diff > 1e-5 {
|
||||
println!("{xs}");
|
||||
println!("{ys}");
|
||||
println!("{ys_steps}");
|
||||
candle::bail!("larger diff than expected {diff}")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv1d() -> Result<()> {
|
||||
for step_size in [1, 2, 3] {
|
||||
for bias in [false, true] {
|
||||
run_conv1d(1, 1, 1, step_size, 5, bias)?;
|
||||
run_conv1d(2, 1, 1, step_size, 5, bias)?;
|
||||
run_conv1d(2, 2, 1, step_size, 6, bias)?;
|
||||
run_conv1d(3, 2, 1, step_size, 8, bias)?;
|
||||
run_conv1d(3, 2, 2, step_size, 8, bias)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_tr1d() -> Result<()> {
|
||||
for step_size in [1, 2, 3] {
|
||||
for bias in [false, true] {
|
||||
run_conv_tr1d(1, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(2, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(3, 1, step_size, 5, bias)?;
|
||||
run_conv_tr1d(3, 2, step_size, 5, bias)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
229
candle-transformers/src/models/mimi/encodec.rs
Normal file
229
candle-transformers/src/models/mimi/encodec.rs
Normal file
@ -0,0 +1,229 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use super::{conv, quantization, seanet, transformer};
|
||||
use candle::{DType, Device, Module, Result, StreamTensor, StreamingModule, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum ResampleMethod {
|
||||
Conv,
|
||||
Interpolate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub channels: usize,
|
||||
pub sample_rate: f64,
|
||||
pub frame_rate: f64,
|
||||
pub renormalize: bool,
|
||||
pub resample_method: ResampleMethod,
|
||||
pub seanet: seanet::Config,
|
||||
pub transformer: transformer::Config,
|
||||
pub quantizer_n_q: usize,
|
||||
pub quantizer_bins: usize,
|
||||
pub quantizer_dim: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// /lustre/scwpod02/client/kyutai/alex/mimi_exp/xps/b7d2bd5a/.hydra/config.yaml
|
||||
pub fn v0_1(num_codebooks: Option<usize>) -> Self {
|
||||
let seanet_cfg = seanet::Config {
|
||||
dimension: 512,
|
||||
channels: 1,
|
||||
causal: true,
|
||||
n_filters: 64,
|
||||
n_residual_layers: 1,
|
||||
activation: candle_nn::Activation::Elu(1.),
|
||||
compress: 2,
|
||||
dilation_base: 2,
|
||||
disable_norm_outer_blocks: 0,
|
||||
final_activation: None,
|
||||
kernel_size: 7,
|
||||
residual_kernel_size: 3,
|
||||
last_kernel_size: 3,
|
||||
lstm: 0,
|
||||
norm: conv::Norm::WeightNorm,
|
||||
pad_mode: conv::PadMode::Constant,
|
||||
ratios: vec![8, 6, 5, 4],
|
||||
true_skip: true,
|
||||
};
|
||||
let transformer_cfg = transformer::Config {
|
||||
d_model: seanet_cfg.dimension,
|
||||
num_heads: 8,
|
||||
num_layers: 8,
|
||||
causal: true,
|
||||
norm_first: true,
|
||||
bias_ff: false,
|
||||
bias_attn: false,
|
||||
layer_scale: Some(0.01),
|
||||
context: 250,
|
||||
conv_kernel_size: 5,
|
||||
use_conv_bias: true,
|
||||
use_conv_block: false,
|
||||
cross_attention: false,
|
||||
max_period: 10000,
|
||||
gating: None,
|
||||
norm: super::NormType::LayerNorm,
|
||||
positional_embedding: transformer::PositionalEmbedding::Rope,
|
||||
|
||||
dim_feedforward: 2048,
|
||||
kv_repeat: 1,
|
||||
conv_layout: true, // see builders.py
|
||||
max_seq_len: 8192, // the transformer works at 25hz so this is ~5 mins.
|
||||
};
|
||||
Config {
|
||||
channels: 1,
|
||||
sample_rate: 24_000.,
|
||||
frame_rate: 12.5,
|
||||
renormalize: true,
|
||||
resample_method: ResampleMethod::Conv,
|
||||
seanet: seanet_cfg,
|
||||
transformer: transformer_cfg,
|
||||
quantizer_n_q: num_codebooks.unwrap_or(16),
|
||||
quantizer_bins: 2048,
|
||||
quantizer_dim: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Encodec {
|
||||
encoder: seanet::SeaNetEncoder,
|
||||
decoder: seanet::SeaNetDecoder,
|
||||
encoder_transformer: transformer::ProjectedTransformer,
|
||||
decoder_transformer: transformer::ProjectedTransformer,
|
||||
downsample: conv::ConvDownsample1d,
|
||||
upsample: conv::ConvTrUpsample1d,
|
||||
quantizer: quantization::SplitResidualVectorQuantizer,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Encodec {
|
||||
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
|
||||
let dim = cfg.seanet.dimension;
|
||||
let encoder = seanet::SeaNetEncoder::new(&cfg.seanet, vb.pp("encoder"))?;
|
||||
let decoder = seanet::SeaNetDecoder::new(&cfg.seanet, vb.pp("decoder"))?;
|
||||
let encoder_transformer = transformer::ProjectedTransformer::new(
|
||||
dim,
|
||||
&[dim],
|
||||
&cfg.transformer,
|
||||
vb.pp("encoder_transformer"),
|
||||
)?;
|
||||
let decoder_transformer = transformer::ProjectedTransformer::new(
|
||||
dim,
|
||||
&[dim],
|
||||
&cfg.transformer,
|
||||
vb.pp("decoder_transformer"),
|
||||
)?;
|
||||
let quantizer = quantization::SplitResidualVectorQuantizer::new(
|
||||
/* dim */ cfg.quantizer_dim,
|
||||
/* input_dim */ Some(dim),
|
||||
/* output_dim */ Some(dim),
|
||||
/* n_q */ cfg.quantizer_n_q,
|
||||
/* bins */ cfg.quantizer_bins,
|
||||
vb.pp("quantizer"),
|
||||
)?;
|
||||
let encoder_frame_rate =
|
||||
cfg.sample_rate / cfg.seanet.ratios.iter().product::<usize>() as f64;
|
||||
|
||||
let downsample_stride = (encoder_frame_rate / cfg.frame_rate) as usize;
|
||||
// `upsample` and `downsample` only apply if frame_rate is different from encoder_frame_rate.
|
||||
let downsample = conv::ConvDownsample1d::new(
|
||||
/* stride */ downsample_stride,
|
||||
/* dim */ dim,
|
||||
/* causal */ true,
|
||||
/* learnt */ true,
|
||||
vb.pp("downsample"),
|
||||
)?;
|
||||
let upsample = conv::ConvTrUpsample1d::new(
|
||||
/* stride */ downsample_stride,
|
||||
/* dim */ dim,
|
||||
/* causal */ true,
|
||||
/* learnt */ true,
|
||||
vb.pp("upsample"),
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
encoder_transformer,
|
||||
decoder_transformer,
|
||||
quantizer,
|
||||
downsample,
|
||||
upsample,
|
||||
config: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn encode_pre_quantize(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
self.encoder_transformer.reset_state();
|
||||
let xs = self.encoder_transformer.forward(&xs)?;
|
||||
let xs = &xs[0];
|
||||
xs.apply(&self.downsample)
|
||||
}
|
||||
|
||||
pub fn encode(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.encoder.forward(xs)?;
|
||||
self.encoder_transformer.reset_state();
|
||||
let xs = self.encoder_transformer.forward(&xs)?;
|
||||
let xs = &xs[0];
|
||||
let xs = xs.apply(&self.downsample)?;
|
||||
let codes = self.quantizer.encode(&xs)?;
|
||||
Ok(codes)
|
||||
}
|
||||
|
||||
pub fn encode_step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let xs = self.encoder.step(xs)?;
|
||||
let xs = self.encoder_transformer.step(&xs)?;
|
||||
let xs = self.downsample.step(&xs)?;
|
||||
match xs.as_option() {
|
||||
None => Ok(().into()),
|
||||
Some(xs) => {
|
||||
let codes = self.quantizer.encode(xs)?;
|
||||
Ok(codes.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, codes: &Tensor) -> Result<Tensor> {
|
||||
let emb = self.quantizer.decode(codes)?;
|
||||
let emb = emb.apply(&self.upsample)?;
|
||||
self.decoder_transformer.reset_state();
|
||||
let outs = self.decoder_transformer.forward(&emb)?;
|
||||
let out = &outs[0];
|
||||
self.decoder.forward(out)
|
||||
}
|
||||
|
||||
pub fn decode_step(&mut self, codes: &StreamTensor) -> Result<StreamTensor> {
|
||||
let emb = match codes.as_option() {
|
||||
Some(codes) => StreamTensor::from_tensor(self.quantizer.decode(codes)?),
|
||||
None => StreamTensor::empty(),
|
||||
};
|
||||
let emb = self.upsample.step(&emb)?;
|
||||
let out = self.decoder_transformer.step(&emb)?;
|
||||
self.decoder.step(&out)
|
||||
}
|
||||
|
||||
pub fn reset_state(&mut self) {
|
||||
self.encoder.reset_state();
|
||||
self.encoder_transformer.reset_state();
|
||||
self.decoder.reset_state();
|
||||
self.decoder_transformer.reset_state();
|
||||
self.upsample.reset_state();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load(model_file: &str, num_codebooks: Option<usize>, dev: &Device) -> Result<Encodec> {
|
||||
let vb =
|
||||
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, dev)? };
|
||||
let cfg = Config::v0_1(num_codebooks);
|
||||
let encodec = Encodec::new(cfg, vb)?;
|
||||
Ok(encodec)
|
||||
}
|
22
candle-transformers/src/models/mimi/mod.rs
Normal file
22
candle-transformers/src/models/mimi/mod.rs
Normal file
@ -0,0 +1,22 @@
|
||||
// Adapted from the reference implementation at:
|
||||
// https://github.com/kyutai-labs/moshi
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
pub use candle;
|
||||
pub use candle_nn;
|
||||
|
||||
pub mod conv;
|
||||
pub mod encodec;
|
||||
pub mod quantization;
|
||||
pub mod seanet;
|
||||
pub mod transformer;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum NormType {
|
||||
RmsNorm,
|
||||
LayerNorm,
|
||||
}
|
||||
|
||||
pub use encodec::{load, Config, Encodec as Model};
|
404
candle-transformers/src/models/mimi/quantization.rs
Normal file
404
candle-transformers/src/models/mimi/quantization.rs
Normal file
@ -0,0 +1,404 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::{linear, Linear, VarBuilder};
|
||||
|
||||
struct CodebookEncode;
|
||||
|
||||
impl candle::CustomOp2 for CodebookEncode {
|
||||
fn name(&self) -> &'static str {
|
||||
"cb"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
lhs_storage: &candle::CpuStorage,
|
||||
lhs_layout: &Layout,
|
||||
rhs_storage: &candle::CpuStorage,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<(candle::CpuStorage, Shape)> {
|
||||
use rayon::prelude::*;
|
||||
|
||||
let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
|
||||
let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
|
||||
if lhs_dim2 != rhs_dim2 {
|
||||
candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
|
||||
}
|
||||
if lhs_dim2 == 0 {
|
||||
candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
|
||||
}
|
||||
let lhs = match lhs_layout.contiguous_offsets() {
|
||||
None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
|
||||
Some((o1, o2)) => {
|
||||
let slice = lhs_storage.as_slice::<f32>()?;
|
||||
&slice[o1..o2]
|
||||
}
|
||||
};
|
||||
let rhs = match rhs_layout.contiguous_offsets() {
|
||||
None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
|
||||
Some((o1, o2)) => {
|
||||
let slice = rhs_storage.as_slice::<f32>()?;
|
||||
&slice[o1..o2]
|
||||
}
|
||||
};
|
||||
let dst = (0..lhs_dim1)
|
||||
.into_par_iter()
|
||||
.map(|idx1| {
|
||||
let mut where_min = 0;
|
||||
let mut min_dist = f32::INFINITY;
|
||||
let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
|
||||
for idx2 in 0..rhs_dim1 {
|
||||
let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
|
||||
let mut dist = 0f32;
|
||||
for (a, b) in lhs.iter().zip(rhs.iter()) {
|
||||
dist += (a - b) * (a - b)
|
||||
}
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
where_min = idx2;
|
||||
}
|
||||
}
|
||||
where_min as u32
|
||||
})
|
||||
.collect();
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, (lhs_dim1,).into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EuclideanCodebook {
|
||||
initialized: Tensor,
|
||||
cluster_usage: Tensor,
|
||||
embedding_sum: Tensor,
|
||||
embedding: Tensor,
|
||||
c2: Tensor,
|
||||
epsilon: f64,
|
||||
dim: usize,
|
||||
span_encode: tracing::Span,
|
||||
span_decode: tracing::Span,
|
||||
}
|
||||
|
||||
impl EuclideanCodebook {
|
||||
pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let epsilon = 1e-5;
|
||||
let initialized = vb.get(1, "initialized")?;
|
||||
let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
|
||||
let embedding_sum = vb.get((codebook_size, dim), "embed_sum")?;
|
||||
let embedding = {
|
||||
let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
|
||||
embedding_sum.broadcast_div(&cluster_usage)?
|
||||
};
|
||||
let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
|
||||
Ok(Self {
|
||||
initialized,
|
||||
cluster_usage,
|
||||
embedding_sum,
|
||||
embedding,
|
||||
c2,
|
||||
epsilon,
|
||||
dim,
|
||||
span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
|
||||
span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
// TODO: avoid repeating this.
|
||||
let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
|
||||
let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
|
||||
// Manual cdist implementation.
|
||||
let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
|
||||
let dists = diff.sqr()?.sum(D::Minus1)?;
|
||||
let codes = dists.argmin(D::Minus1)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
let dot_prod = xs.matmul(&self.embedding.t()?)?;
|
||||
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let mut target_shape = xs.dims().to_vec();
|
||||
target_shape.pop();
|
||||
let xs = xs.flatten_to(D::Minus2)?;
|
||||
let _ = xs.dims2()?;
|
||||
let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
|
||||
codes.reshape(target_shape)
|
||||
}
|
||||
|
||||
pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_decode.enter();
|
||||
// let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
|
||||
let mut final_dims = indexes.dims().to_vec();
|
||||
final_dims.push(self.dim);
|
||||
let indexes = indexes.flatten_all()?;
|
||||
let values = self.embedding.index_select(&indexes, 0)?;
|
||||
let values = values.reshape(final_dims)?;
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorQuantization {
|
||||
project_in: Option<Linear>,
|
||||
project_out: Option<Linear>,
|
||||
codebook: EuclideanCodebook,
|
||||
}
|
||||
|
||||
impl VectorQuantization {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let codebook_dim = codebook_dim.unwrap_or(dim);
|
||||
let (project_in, project_out) = if codebook_dim == dim {
|
||||
(None, None)
|
||||
} else {
|
||||
let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
|
||||
let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
|
||||
(Some(p_in), Some(p_out))
|
||||
};
|
||||
let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("codebook"))?;
|
||||
Ok(Self {
|
||||
project_in,
|
||||
project_out,
|
||||
codebook,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.t()?.apply(&self.project_in.as_ref())?;
|
||||
self.codebook.encode_slow(&xs)
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let quantized = self.codebook.decode(codes)?;
|
||||
let quantized = match &self.project_out {
|
||||
None => quantized,
|
||||
Some(p) => quantized.apply(p)?,
|
||||
};
|
||||
quantized.t()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResidualVectorQuantization {
|
||||
layers: Vec<VectorQuantization>,
|
||||
}
|
||||
|
||||
impl ResidualVectorQuantization {
|
||||
pub fn new(
|
||||
n_q: usize,
|
||||
dim: usize,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let vb = vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(n_q);
|
||||
for i in 0..n_q {
|
||||
let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut codes = Vec::with_capacity(self.layers.len());
|
||||
let mut residual = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let indices = layer.encode(&residual)?;
|
||||
let quantized = layer.decode(&indices)?;
|
||||
residual = (residual - quantized)?;
|
||||
codes.push(indices)
|
||||
}
|
||||
Tensor::stack(&codes, 0)
|
||||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
if self.layers.is_empty() {
|
||||
candle::bail!("empty layers in ResidualVectorQuantization")
|
||||
}
|
||||
if self.layers.len() != xs.dim(0)? {
|
||||
candle::bail!(
|
||||
"mismatch between the number of layers {} and the code shape {:?}",
|
||||
self.layers.len(),
|
||||
xs.shape()
|
||||
)
|
||||
}
|
||||
let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
|
||||
for (i, layer) in self.layers.iter().enumerate().skip(1) {
|
||||
let xs = xs.i(i)?;
|
||||
quantized = (quantized + layer.decode(&xs))?
|
||||
}
|
||||
Ok(quantized)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResidualVectorQuantizer {
|
||||
vq: ResidualVectorQuantization,
|
||||
input_proj: Option<candle_nn::Conv1d>,
|
||||
output_proj: Option<candle_nn::Conv1d>,
|
||||
}
|
||||
|
||||
impl ResidualVectorQuantizer {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
input_dim: Option<usize>,
|
||||
output_dim: Option<usize>,
|
||||
n_q: usize,
|
||||
bins: usize,
|
||||
force_projection: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let input_dim = input_dim.unwrap_or(dim);
|
||||
let output_dim = output_dim.unwrap_or(dim);
|
||||
|
||||
let input_proj = if input_dim == dim && !force_projection {
|
||||
None
|
||||
} else {
|
||||
let c = candle_nn::conv1d_no_bias(
|
||||
input_dim,
|
||||
dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("input_proj"),
|
||||
)?;
|
||||
Some(c)
|
||||
};
|
||||
let output_proj = if output_dim == dim && !force_projection {
|
||||
None
|
||||
} else {
|
||||
let c = candle_nn::conv1d_no_bias(
|
||||
dim,
|
||||
output_dim,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("output_proj"),
|
||||
)?;
|
||||
Some(c)
|
||||
};
|
||||
|
||||
let vq = ResidualVectorQuantization::new(
|
||||
n_q, dim, /* codebook_size */ bins, /* codebook_dim */ None, vb,
|
||||
)?;
|
||||
Ok(Self {
|
||||
vq,
|
||||
input_proj,
|
||||
output_proj,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
|
||||
codes.transpose(0, 1)
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
// codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
|
||||
let codes = codes.transpose(0, 1)?;
|
||||
let quantized = self.vq.decode(&codes)?;
|
||||
match &self.output_proj {
|
||||
None => Ok(quantized),
|
||||
Some(p) => quantized.apply(p),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
|
||||
// concatenate the indexes.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SplitResidualVectorQuantizer {
|
||||
rvq_first: ResidualVectorQuantizer,
|
||||
rvq_rest: ResidualVectorQuantizer,
|
||||
n_q: usize,
|
||||
span_encode: tracing::Span,
|
||||
span_decode: tracing::Span,
|
||||
}
|
||||
|
||||
impl SplitResidualVectorQuantizer {
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
input_dim: Option<usize>,
|
||||
output_dim: Option<usize>,
|
||||
n_q: usize,
|
||||
bins: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let rvq_first = ResidualVectorQuantizer::new(
|
||||
dim,
|
||||
input_dim,
|
||||
output_dim,
|
||||
1,
|
||||
bins,
|
||||
true,
|
||||
vb.pp("semantic_residual_vector_quantizer"),
|
||||
)?;
|
||||
let rvq_rest = ResidualVectorQuantizer::new(
|
||||
dim,
|
||||
input_dim,
|
||||
output_dim,
|
||||
n_q - 1,
|
||||
bins,
|
||||
true,
|
||||
vb.pp("acoustic_residual_vector_quantizer"),
|
||||
)?;
|
||||
let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
|
||||
let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
|
||||
Ok(Self {
|
||||
rvq_first,
|
||||
rvq_rest,
|
||||
n_q,
|
||||
span_encode,
|
||||
span_decode,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span_encode.enter();
|
||||
let codes = self.rvq_first.encode(xs)?;
|
||||
if self.n_q > 1 {
|
||||
// We encode xs again here rather than the residual. The decomposition is not
|
||||
// hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
|
||||
// for rvq_rest.
|
||||
let rest_codes = self.rvq_rest.encode(xs)?;
|
||||
Tensor::cat(&[codes, rest_codes], 1)
|
||||
} else {
|
||||
Ok(codes)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
// codes is [B, K, T], with T frames, K nb of codebooks.
|
||||
let _enter = self.span_decode.enter();
|
||||
let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
|
||||
let quantized = if self.n_q > 1 {
|
||||
(quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
|
||||
} else {
|
||||
quantized
|
||||
};
|
||||
Ok(quantized)
|
||||
}
|
||||
}
|
465
candle-transformers/src/models/mimi/seanet.rs
Normal file
465
candle-transformers/src/models/mimi/seanet.rs
Normal file
@ -0,0 +1,465 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{streaming, Module, Result, StreamTensor, StreamingModule, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
use super::conv::{StreamableConv1d, StreamableConvTranspose1d};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub dimension: usize,
|
||||
pub channels: usize,
|
||||
pub causal: bool,
|
||||
pub n_filters: usize,
|
||||
pub n_residual_layers: usize,
|
||||
pub ratios: Vec<usize>,
|
||||
pub activation: candle_nn::Activation,
|
||||
pub norm: super::conv::Norm,
|
||||
pub kernel_size: usize,
|
||||
pub residual_kernel_size: usize,
|
||||
pub last_kernel_size: usize,
|
||||
pub dilation_base: usize,
|
||||
pub pad_mode: super::conv::PadMode,
|
||||
pub true_skip: bool,
|
||||
pub compress: usize,
|
||||
pub lstm: usize,
|
||||
pub disable_norm_outer_blocks: usize,
|
||||
pub final_activation: Option<candle_nn::Activation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SeaNetResnetBlock {
|
||||
block: Vec<StreamableConv1d>,
|
||||
shortcut: Option<StreamableConv1d>,
|
||||
activation: candle_nn::Activation,
|
||||
skip_op: candle::StreamingBinOp,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl SeaNetResnetBlock {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
k_sizes_and_dilations: &[(usize, usize)],
|
||||
activation: candle_nn::Activation,
|
||||
norm: Option<super::conv::Norm>,
|
||||
causal: bool,
|
||||
pad_mode: super::conv::PadMode,
|
||||
compress: usize,
|
||||
true_skip: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut block = Vec::with_capacity(k_sizes_and_dilations.len());
|
||||
let hidden = dim / compress;
|
||||
let vb_b = vb.pp("block");
|
||||
for (i, (k_size, dilation)) in k_sizes_and_dilations.iter().enumerate() {
|
||||
let in_c = if i == 0 { dim } else { hidden };
|
||||
let out_c = if i == k_sizes_and_dilations.len() - 1 {
|
||||
dim
|
||||
} else {
|
||||
hidden
|
||||
};
|
||||
let c = StreamableConv1d::new(
|
||||
in_c,
|
||||
out_c,
|
||||
/* k_size */ *k_size,
|
||||
/* stride */ 1,
|
||||
/* dilation */ *dilation,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ causal,
|
||||
/* norm */ norm,
|
||||
/* pad_mode */ pad_mode,
|
||||
vb_b.pp(2 * i + 1),
|
||||
)?;
|
||||
block.push(c)
|
||||
}
|
||||
let shortcut = if true_skip {
|
||||
None
|
||||
} else {
|
||||
let c = StreamableConv1d::new(
|
||||
dim,
|
||||
dim,
|
||||
/* k_size */ 1,
|
||||
/* stride */ 1,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ causal,
|
||||
/* norm */ norm,
|
||||
/* pad_mode */ pad_mode,
|
||||
vb.pp("shortcut"),
|
||||
)?;
|
||||
Some(c)
|
||||
};
|
||||
Ok(Self {
|
||||
block,
|
||||
shortcut,
|
||||
activation,
|
||||
skip_op: streaming::StreamingBinOp::new(streaming::BinOp::Add, candle::D::Minus1),
|
||||
span: tracing::span!(tracing::Level::TRACE, "sea-resnet"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SeaNetResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut ys = xs.clone();
|
||||
for block in self.block.iter() {
|
||||
ys = ys.apply(&self.activation)?.apply(block)?;
|
||||
}
|
||||
match self.shortcut.as_ref() {
|
||||
None => ys + xs,
|
||||
Some(shortcut) => ys + xs.apply(shortcut),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for SeaNetResnetBlock {
|
||||
fn reset_state(&mut self) {
|
||||
for block in self.block.iter_mut() {
|
||||
block.reset_state()
|
||||
}
|
||||
if let Some(shortcut) = self.shortcut.as_mut() {
|
||||
shortcut.reset_state()
|
||||
}
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut ys = xs.clone();
|
||||
for block in self.block.iter_mut() {
|
||||
ys = block.step(&ys.apply(&self.activation)?)?;
|
||||
}
|
||||
match self.shortcut.as_ref() {
|
||||
None => self.skip_op.step(&ys, xs),
|
||||
Some(shortcut) => self.skip_op.step(&ys, &xs.apply(shortcut)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct EncoderLayer {
|
||||
residuals: Vec<SeaNetResnetBlock>,
|
||||
downsample: StreamableConv1d,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SeaNetEncoder {
|
||||
init_conv1d: StreamableConv1d,
|
||||
activation: candle_nn::Activation,
|
||||
layers: Vec<EncoderLayer>,
|
||||
final_conv1d: StreamableConv1d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl SeaNetEncoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
if cfg.lstm > 0 {
|
||||
candle::bail!("seanet lstm is not supported")
|
||||
}
|
||||
let n_blocks = 2 + cfg.ratios.len();
|
||||
let mut mult = 1usize;
|
||||
let init_norm = if cfg.disable_norm_outer_blocks >= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let mut layer_idx = 0;
|
||||
let vb = vb.pp("layers");
|
||||
let init_conv1d = StreamableConv1d::new(
|
||||
cfg.channels,
|
||||
mult * cfg.n_filters,
|
||||
cfg.kernel_size,
|
||||
/* stride */ 1,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ cfg.causal,
|
||||
/* norm */ init_norm,
|
||||
/* pad_mode */ cfg.pad_mode,
|
||||
vb.pp(layer_idx),
|
||||
)?;
|
||||
layer_idx += 1;
|
||||
let mut layers = Vec::with_capacity(cfg.ratios.len());
|
||||
|
||||
for (i, &ratio) in cfg.ratios.iter().rev().enumerate() {
|
||||
let norm = if cfg.disable_norm_outer_blocks >= i + 2 {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
|
||||
for j in 0..cfg.n_residual_layers {
|
||||
let resnet_block = SeaNetResnetBlock::new(
|
||||
mult * cfg.n_filters,
|
||||
&[
|
||||
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
|
||||
(1, 1),
|
||||
],
|
||||
cfg.activation,
|
||||
norm,
|
||||
cfg.causal,
|
||||
cfg.pad_mode,
|
||||
cfg.compress,
|
||||
cfg.true_skip,
|
||||
vb.pp(layer_idx),
|
||||
)?;
|
||||
residuals.push(resnet_block);
|
||||
layer_idx += 1;
|
||||
}
|
||||
let downsample = StreamableConv1d::new(
|
||||
mult * cfg.n_filters,
|
||||
mult * cfg.n_filters * 2,
|
||||
/* k_size */ ratio * 2,
|
||||
/* stride */ ratio,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ true,
|
||||
/* norm */ norm,
|
||||
/* pad_mode */ cfg.pad_mode,
|
||||
vb.pp(layer_idx + 1),
|
||||
)?;
|
||||
layer_idx += 2;
|
||||
let layer = EncoderLayer {
|
||||
downsample,
|
||||
residuals,
|
||||
};
|
||||
layers.push(layer);
|
||||
mult *= 2
|
||||
}
|
||||
|
||||
let final_norm = if cfg.disable_norm_outer_blocks >= n_blocks {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let final_conv1d = StreamableConv1d::new(
|
||||
mult * cfg.n_filters,
|
||||
cfg.dimension,
|
||||
cfg.last_kernel_size,
|
||||
/* stride */ 1,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ cfg.causal,
|
||||
/* norm */ final_norm,
|
||||
/* pad_mode */ cfg.pad_mode,
|
||||
vb.pp(layer_idx + 1),
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv1d,
|
||||
activation: cfg.activation,
|
||||
layers,
|
||||
final_conv1d,
|
||||
span: tracing::span!(tracing::Level::TRACE, "sea-encoder"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SeaNetEncoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.apply(&self.init_conv1d)?;
|
||||
for layer in self.layers.iter() {
|
||||
for residual in layer.residuals.iter() {
|
||||
xs = xs.apply(residual)?
|
||||
}
|
||||
xs = xs.apply(&self.activation)?.apply(&layer.downsample)?;
|
||||
}
|
||||
xs.apply(&self.activation)?.apply(&self.final_conv1d)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for SeaNetEncoder {
|
||||
fn reset_state(&mut self) {
|
||||
self.init_conv1d.reset_state();
|
||||
self.layers.iter_mut().for_each(|v| {
|
||||
v.residuals.iter_mut().for_each(|v| v.reset_state());
|
||||
v.downsample.reset_state()
|
||||
});
|
||||
self.final_conv1d.reset_state();
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.init_conv1d.step(xs)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
for residual in layer.residuals.iter_mut() {
|
||||
xs = residual.step(&xs)?;
|
||||
}
|
||||
xs = layer.downsample.step(&xs.apply(&self.activation)?)?;
|
||||
}
|
||||
self.final_conv1d.step(&xs.apply(&self.activation)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
upsample: StreamableConvTranspose1d,
|
||||
residuals: Vec<SeaNetResnetBlock>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SeaNetDecoder {
|
||||
init_conv1d: StreamableConv1d,
|
||||
activation: candle_nn::Activation,
|
||||
layers: Vec<DecoderLayer>,
|
||||
final_conv1d: StreamableConv1d,
|
||||
final_activation: Option<candle_nn::Activation>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl SeaNetDecoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
if cfg.lstm > 0 {
|
||||
candle::bail!("seanet lstm is not supported")
|
||||
}
|
||||
let n_blocks = 2 + cfg.ratios.len();
|
||||
let mut mult = 1 << cfg.ratios.len();
|
||||
let init_norm = if cfg.disable_norm_outer_blocks == n_blocks {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let mut layer_idx = 0;
|
||||
let vb = vb.pp("layers");
|
||||
let init_conv1d = StreamableConv1d::new(
|
||||
cfg.dimension,
|
||||
mult * cfg.n_filters,
|
||||
cfg.kernel_size,
|
||||
/* stride */ 1,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ cfg.causal,
|
||||
/* norm */ init_norm,
|
||||
/* pad_mode */ cfg.pad_mode,
|
||||
vb.pp(layer_idx),
|
||||
)?;
|
||||
layer_idx += 1;
|
||||
let mut layers = Vec::with_capacity(cfg.ratios.len());
|
||||
for (i, &ratio) in cfg.ratios.iter().enumerate() {
|
||||
let norm = if cfg.disable_norm_outer_blocks + i + 1 >= n_blocks {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let upsample = StreamableConvTranspose1d::new(
|
||||
mult * cfg.n_filters,
|
||||
mult * cfg.n_filters / 2,
|
||||
/* k_size */ ratio * 2,
|
||||
/* stride */ ratio,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ true,
|
||||
/* norm */ norm,
|
||||
vb.pp(layer_idx + 1),
|
||||
)?;
|
||||
layer_idx += 2;
|
||||
|
||||
let mut residuals = Vec::with_capacity(cfg.n_residual_layers);
|
||||
for j in 0..cfg.n_residual_layers {
|
||||
let resnet_block = SeaNetResnetBlock::new(
|
||||
mult * cfg.n_filters / 2,
|
||||
&[
|
||||
(cfg.residual_kernel_size, cfg.dilation_base.pow(j as u32)),
|
||||
(1, 1),
|
||||
],
|
||||
cfg.activation,
|
||||
norm,
|
||||
cfg.causal,
|
||||
cfg.pad_mode,
|
||||
cfg.compress,
|
||||
cfg.true_skip,
|
||||
vb.pp(layer_idx),
|
||||
)?;
|
||||
residuals.push(resnet_block);
|
||||
layer_idx += 1;
|
||||
}
|
||||
let layer = DecoderLayer {
|
||||
upsample,
|
||||
residuals,
|
||||
};
|
||||
layers.push(layer);
|
||||
mult /= 2
|
||||
}
|
||||
let final_norm = if cfg.disable_norm_outer_blocks >= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(cfg.norm)
|
||||
};
|
||||
let final_conv1d = StreamableConv1d::new(
|
||||
cfg.n_filters,
|
||||
cfg.channels,
|
||||
cfg.last_kernel_size,
|
||||
/* stride */ 1,
|
||||
/* dilation */ 1,
|
||||
/* groups */ 1,
|
||||
/* bias */ true,
|
||||
/* causal */ cfg.causal,
|
||||
/* norm */ final_norm,
|
||||
/* pad_mode */ cfg.pad_mode,
|
||||
vb.pp(layer_idx + 1),
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv1d,
|
||||
activation: cfg.activation,
|
||||
layers,
|
||||
final_conv1d,
|
||||
final_activation: cfg.final_activation,
|
||||
span: tracing::span!(tracing::Level::TRACE, "sea-decoder"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SeaNetDecoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.apply(&self.init_conv1d)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = xs.apply(&self.activation)?.apply(&layer.upsample)?;
|
||||
for residual in layer.residuals.iter() {
|
||||
xs = xs.apply(residual)?
|
||||
}
|
||||
}
|
||||
let xs = xs.apply(&self.activation)?.apply(&self.final_conv1d)?;
|
||||
let xs = match self.final_activation.as_ref() {
|
||||
None => xs,
|
||||
Some(act) => xs.apply(act)?,
|
||||
};
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for SeaNetDecoder {
|
||||
fn reset_state(&mut self) {
|
||||
self.init_conv1d.reset_state();
|
||||
self.layers.iter_mut().for_each(|v| {
|
||||
v.residuals.iter_mut().for_each(|v| v.reset_state());
|
||||
v.upsample.reset_state()
|
||||
});
|
||||
self.final_conv1d.reset_state();
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.init_conv1d.step(xs)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.upsample.step(&xs.apply(&self.activation)?)?;
|
||||
for residual in layer.residuals.iter_mut() {
|
||||
xs = residual.step(&xs)?;
|
||||
}
|
||||
}
|
||||
let xs = self.final_conv1d.step(&xs.apply(&self.activation)?)?;
|
||||
let xs = match self.final_activation.as_ref() {
|
||||
None => xs,
|
||||
Some(act) => xs.apply(act)?,
|
||||
};
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
802
candle-transformers/src/models/mimi/transformer.rs
Normal file
802
candle-transformers/src/models/mimi/transformer.rs
Normal file
@ -0,0 +1,802 @@
|
||||
// Copyright (c) Kyutai, all rights reserved.
|
||||
// This source code is licensed under the license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
use candle::{DType, Device, IndexOp, Module, Result, StreamTensor, StreamingModule, Tensor, D};
|
||||
use candle_nn::{linear_no_bias, Linear, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn linear(in_d: usize, out_d: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_d, out_d, vb)
|
||||
} else {
|
||||
linear_no_bias(in_d, out_d, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum PositionalEmbedding {
|
||||
Rope,
|
||||
Sin,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub d_model: usize,
|
||||
pub num_heads: usize,
|
||||
pub num_layers: usize,
|
||||
pub causal: bool,
|
||||
pub norm_first: bool,
|
||||
pub bias_ff: bool,
|
||||
pub bias_attn: bool,
|
||||
pub layer_scale: Option<f64>,
|
||||
pub positional_embedding: PositionalEmbedding,
|
||||
pub use_conv_block: bool,
|
||||
pub cross_attention: bool,
|
||||
pub conv_kernel_size: usize,
|
||||
pub use_conv_bias: bool,
|
||||
pub gating: Option<candle_nn::Activation>,
|
||||
pub norm: super::NormType,
|
||||
pub context: usize,
|
||||
pub max_period: usize,
|
||||
pub max_seq_len: usize,
|
||||
|
||||
pub kv_repeat: usize,
|
||||
pub dim_feedforward: usize,
|
||||
pub conv_layout: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
pub fn new(dim: usize, max_seq_len: usize, theta: f32, dev: &Device) -> Result<Self> {
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / theta.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
span: tracing::span!(tracing::Level::TRACE, "rot"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn apply_rotary_emb(&self, qk: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_b_size, _nheads, seqlen, _headdim) = qk.dims4()?;
|
||||
let qk_dtype = qk.dtype();
|
||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
|
||||
candle_nn::rotary_emb::rope_i(&qk.to_dtype(DType::F32)?, &c, &s)?.to_dtype(qk_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LayerScale {
|
||||
scale: Tensor,
|
||||
}
|
||||
|
||||
impl LayerScale {
|
||||
pub fn new(d_model: usize, _init: f64, vb: VarBuilder) -> Result<Self> {
|
||||
let scale = vb.get(d_model, "scale")?;
|
||||
Ok(Self { scale })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&self.scale)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_mask(
|
||||
size1: usize,
|
||||
size2: usize,
|
||||
context: usize,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size1)
|
||||
.flat_map(|i| {
|
||||
(0..size2)
|
||||
.map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
|
||||
})
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size1, size2), device)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingMultiheadAttention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
out_proj: Linear,
|
||||
kv_repeat: usize,
|
||||
num_heads: usize,
|
||||
context: usize,
|
||||
neg_inf: Tensor,
|
||||
rope: Option<Arc<RotaryEmbedding>>,
|
||||
kv_cache: candle_nn::kv_cache::KvCache,
|
||||
pos: usize,
|
||||
use_flash_attn: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamingMultiheadAttention {
|
||||
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let embed_dim = cfg.d_model;
|
||||
let num_kv = cfg.num_heads / cfg.kv_repeat;
|
||||
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
|
||||
let q_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("q_proj"))?;
|
||||
let k_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(embed_dim, kv_dim, cfg.bias_attn, vb.pp("v_proj"))?;
|
||||
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("o_proj"))?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
out_proj,
|
||||
rope: rope.clone(),
|
||||
kv_repeat: cfg.kv_repeat,
|
||||
num_heads: cfg.num_heads,
|
||||
context: cfg.context,
|
||||
neg_inf,
|
||||
kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
|
||||
pos: 0,
|
||||
use_flash_attn: false,
|
||||
span: tracing::span!(tracing::Level::TRACE, "mha"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
if self.kv_repeat != 1 {
|
||||
candle::bail!("only kv-repeat = 1 is supported")
|
||||
}
|
||||
let (b, t, hd) = xs.dims3()?;
|
||||
let head_dim = hd / self.num_heads;
|
||||
let q = xs
|
||||
.apply(&self.q_proj)?
|
||||
.reshape((b, t, self.num_heads, head_dim))?;
|
||||
let k = xs
|
||||
.apply(&self.k_proj)?
|
||||
.reshape((b, t, self.num_heads, head_dim))?;
|
||||
let v = xs
|
||||
.apply(&self.v_proj)?
|
||||
.reshape((b, t, self.num_heads, head_dim))?;
|
||||
// qk_layer_norm = None
|
||||
// kv_repeat = 1, otherwise we would need repeat_kv
|
||||
let mut q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
|
||||
let mut k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
|
||||
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
|
||||
if let Some(rope) = &self.rope {
|
||||
q = rope.apply_rotary_emb(&q, self.pos)?;
|
||||
k = rope.apply_rotary_emb(&k, self.pos)?;
|
||||
}
|
||||
|
||||
let (k, v) = {
|
||||
self.pos += k.dim(2)?;
|
||||
self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?
|
||||
};
|
||||
// The KV cache keeps all the data at the moment, we want to trim
|
||||
// down the part that comes from the cache to at most context to
|
||||
// be coherent with the mask shape we provide.
|
||||
let k_len = k.dim(2)?;
|
||||
let k_target_len = t + usize::min(self.context, k_len - t);
|
||||
let (k, v) = if k_target_len < k_len {
|
||||
let k = k.narrow(2, k_len - k_target_len, k_target_len)?;
|
||||
let v = v.narrow(2, k_len - k_target_len, k_target_len)?;
|
||||
(k, v)
|
||||
} else {
|
||||
(k.clone(), v.clone())
|
||||
};
|
||||
|
||||
let xs = if q.dtype() == DType::BF16 && self.use_flash_attn {
|
||||
let q = q.transpose(1, 2)?;
|
||||
let k = k.transpose(1, 2)?;
|
||||
let v = v.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, softmax_scale, t > 1)?.transpose(1, 2)?
|
||||
} else {
|
||||
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
|
||||
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
|
||||
|
||||
let pre_ws = match mask {
|
||||
None => pre_ws,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_left((b, self.num_heads))?;
|
||||
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
|
||||
mask.where_cond(&neg_inf, &pre_ws)?
|
||||
}
|
||||
};
|
||||
|
||||
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
|
||||
ws.matmul(&v)? // b,h,t,d
|
||||
};
|
||||
let xs = xs
|
||||
.transpose(1, 2)? // b,t,h,d
|
||||
.reshape((b, t, hd))?
|
||||
.apply(&self.out_proj)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.kv_cache.reset()
|
||||
}
|
||||
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
self.kv_cache = kv_cache
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingMultiheadCrossAttention {
|
||||
in_proj_q: Linear,
|
||||
in_proj_k: Linear,
|
||||
in_proj_v: Linear,
|
||||
out_proj: Linear,
|
||||
kv_repeat: usize,
|
||||
num_heads: usize,
|
||||
neg_inf: Tensor,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamingMultiheadCrossAttention {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let embed_dim = cfg.d_model;
|
||||
let num_kv = cfg.num_heads / cfg.kv_repeat;
|
||||
let kv_dim = num_kv * (embed_dim / cfg.num_heads);
|
||||
let out_dim = embed_dim + 2 * kv_dim;
|
||||
let in_proj_weight = vb.get((out_dim, embed_dim), "in_proj_weight")?;
|
||||
let in_proj_weight_q = in_proj_weight.narrow(0, 0, embed_dim)?;
|
||||
let in_proj_weight_k = in_proj_weight.narrow(0, embed_dim, kv_dim)?;
|
||||
let in_proj_weight_v = in_proj_weight.narrow(0, embed_dim + kv_dim, kv_dim)?;
|
||||
let (in_proj_bias_q, in_proj_bias_k, in_proj_bias_v) = if cfg.bias_attn {
|
||||
let b = vb.get(out_dim, "in_proj_bias")?;
|
||||
let q = b.narrow(0, 0, embed_dim)?;
|
||||
let k = b.narrow(0, embed_dim, kv_dim)?;
|
||||
let v = b.narrow(0, embed_dim + kv_dim, kv_dim)?;
|
||||
(Some(q), Some(k), Some(v))
|
||||
} else {
|
||||
(None, None, None)
|
||||
};
|
||||
let in_proj_q = Linear::new(in_proj_weight_q, in_proj_bias_q);
|
||||
let in_proj_k = Linear::new(in_proj_weight_k, in_proj_bias_k);
|
||||
let in_proj_v = Linear::new(in_proj_weight_v, in_proj_bias_v);
|
||||
let out_proj = linear(embed_dim, embed_dim, cfg.bias_attn, vb.pp("out_proj"))?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?;
|
||||
Ok(Self {
|
||||
in_proj_q,
|
||||
in_proj_k,
|
||||
in_proj_v,
|
||||
out_proj,
|
||||
kv_repeat: cfg.kv_repeat,
|
||||
num_heads: cfg.num_heads,
|
||||
neg_inf,
|
||||
span: tracing::span!(tracing::Level::TRACE, "mhca"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, ca_src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
if self.kv_repeat != 1 {
|
||||
candle::bail!("only kv-repeat = 1 is supported")
|
||||
}
|
||||
let (b, t, hd) = xs.dims3()?;
|
||||
let head_dim = hd / self.num_heads;
|
||||
// time_dim = 1, layout: b,t,h,d
|
||||
let q = xs.apply(&self.in_proj_q)?;
|
||||
let k = ca_src.apply(&self.in_proj_k)?;
|
||||
let v = ca_src.apply(&self.in_proj_v)?;
|
||||
let (ca_b, ca_t, ca_dim) = k.dims3()?;
|
||||
let q = q.reshape((b, t, self.num_heads, head_dim))?;
|
||||
let k = k.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
|
||||
let v = v.reshape((ca_b, ca_t, ca_dim / head_dim, head_dim))?;
|
||||
// qk_layer_norm = None
|
||||
// kv_repeat = 1, otherwise we would need repeat_kv
|
||||
let q = q.transpose(1, 2)?.contiguous()?; // b,h,t,d
|
||||
let k = k.transpose(1, 2)?.contiguous()?; // b,h,k,d
|
||||
let v = v.transpose(1, 2)?.contiguous()?; // b,h,k,d
|
||||
|
||||
let pre_ws = q.matmul(&k.t()?)?; // b,h,t,k
|
||||
let pre_ws = (pre_ws * (head_dim as f64).powf(-0.5))?;
|
||||
|
||||
let pre_ws = match mask {
|
||||
None => pre_ws,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_left((b, self.num_heads))?;
|
||||
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
|
||||
mask.where_cond(&neg_inf, &pre_ws)?
|
||||
}
|
||||
};
|
||||
|
||||
let ws = candle_nn::ops::softmax_last_dim(&pre_ws)?; // b,h,t,k
|
||||
let xs = ws.matmul(&v)?; // b,h,t,d
|
||||
let xs = xs
|
||||
.transpose(1, 2)? // b,t,h,d
|
||||
.reshape((b, t, hd))?
|
||||
.apply(&self.out_proj)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Mlp {
|
||||
NoGating {
|
||||
span1: tracing::Span,
|
||||
linear1: Linear,
|
||||
span2: tracing::Span,
|
||||
linear2: Linear,
|
||||
span: tracing::Span,
|
||||
},
|
||||
Gating {
|
||||
linear_in: Linear,
|
||||
linear_out: Linear,
|
||||
activation: candle_nn::Activation,
|
||||
span: tracing::Span,
|
||||
},
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let d_model = cfg.d_model;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||
|
||||
match cfg.gating {
|
||||
None => {
|
||||
let span1 = tracing::span!(tracing::Level::TRACE, "lin1");
|
||||
let span2 = tracing::span!(tracing::Level::TRACE, "lin2");
|
||||
let linear1 = linear(d_model, cfg.dim_feedforward, cfg.bias_ff, vb.pp("mlp.fc1"))?;
|
||||
let linear2 = linear(cfg.dim_feedforward, d_model, cfg.bias_ff, vb.pp("mlp.fc2"))?;
|
||||
Ok(Self::NoGating {
|
||||
linear1,
|
||||
linear2,
|
||||
span,
|
||||
span1,
|
||||
span2,
|
||||
})
|
||||
}
|
||||
Some(activation) => {
|
||||
let vb = vb.pp("gating");
|
||||
let hidden = if cfg.dim_feedforward == 4 * d_model {
|
||||
11 * d_model / 4
|
||||
} else {
|
||||
2 * cfg.dim_feedforward / 3
|
||||
};
|
||||
// TODO: Maybe use bias_ff here?
|
||||
let linear_in = linear(d_model, 2 * hidden, false, vb.pp("linear_in"))?;
|
||||
let linear_out = linear(hidden, d_model, false, vb.pp("linear_out"))?;
|
||||
Ok(Self::Gating {
|
||||
linear_in,
|
||||
linear_out,
|
||||
activation,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::NoGating {
|
||||
linear1,
|
||||
linear2,
|
||||
span,
|
||||
span1,
|
||||
span2,
|
||||
} => {
|
||||
let _enter = span.enter();
|
||||
let xs = {
|
||||
let _enter = span1.enter();
|
||||
xs.apply(linear1)?
|
||||
};
|
||||
let xs = xs.gelu_erf()?;
|
||||
{
|
||||
let _enter = span2.enter();
|
||||
xs.apply(linear2)
|
||||
}
|
||||
}
|
||||
Self::Gating {
|
||||
linear_in,
|
||||
linear_out,
|
||||
activation,
|
||||
span,
|
||||
} => {
|
||||
let _enter = span.enter();
|
||||
let xs = xs.apply(linear_in)?;
|
||||
let (b, t, _) = xs.dims3()?;
|
||||
let xs = xs.reshape((b, t, 2, ()))?;
|
||||
let xs = (xs.i((.., .., 0))?.apply(activation)? * xs.i((.., .., 1))?)?;
|
||||
xs.apply(linear_out)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RmsNorm {
|
||||
pub(crate) alpha: Tensor,
|
||||
pub(crate) eps: f32,
|
||||
}
|
||||
|
||||
impl RmsNorm {
|
||||
pub fn new(d_model: usize, eps: f32, vb: VarBuilder) -> Result<Self> {
|
||||
let alpha = vb.get((1, 1, d_model), "alpha")?.reshape(d_model)?;
|
||||
Ok(Self { alpha, eps })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
candle_nn::ops::rms_norm(xs, &self.alpha, self.eps)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Norm {
|
||||
LayerNorm(candle_nn::LayerNorm),
|
||||
RmsNorm(RmsNorm),
|
||||
}
|
||||
|
||||
impl Norm {
|
||||
pub fn new(d_model: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let norm = match cfg.norm {
|
||||
super::NormType::LayerNorm => {
|
||||
let norm = candle_nn::layer_norm(d_model, 1e-5, vb)?;
|
||||
Self::LayerNorm(norm)
|
||||
}
|
||||
super::NormType::RmsNorm => {
|
||||
let norm = RmsNorm::new(d_model, 1e-8, vb)?;
|
||||
Self::RmsNorm(norm)
|
||||
}
|
||||
};
|
||||
Ok(norm)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Norm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::LayerNorm(m) => m.forward(xs),
|
||||
Self::RmsNorm(m) => m.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingTransformerLayer {
|
||||
self_attn: StreamingMultiheadAttention,
|
||||
mlp: Mlp,
|
||||
norm1: Norm,
|
||||
norm2: Norm,
|
||||
layer_scale_1: Option<LayerScale>,
|
||||
layer_scale_2: Option<LayerScale>,
|
||||
cross_attn: Option<(candle_nn::LayerNorm, StreamingMultiheadCrossAttention)>,
|
||||
norm_first: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl StreamingTransformerLayer {
|
||||
pub fn new(rope: &Option<Arc<RotaryEmbedding>>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
if cfg.use_conv_block {
|
||||
candle::bail!("conv-block is not supported")
|
||||
}
|
||||
let d_model = cfg.d_model;
|
||||
let mlp = Mlp::new(cfg, vb.clone())?;
|
||||
let (norm1, norm2) = match cfg.norm {
|
||||
super::NormType::LayerNorm => {
|
||||
let norm1 = candle_nn::layer_norm(d_model, 1e-5, vb.pp("input_layernorm"))?;
|
||||
let norm2 =
|
||||
candle_nn::layer_norm(d_model, 1e-5, vb.pp("post_attention_layernorm"))?;
|
||||
(Norm::LayerNorm(norm1), Norm::LayerNorm(norm2))
|
||||
}
|
||||
super::NormType::RmsNorm => {
|
||||
let norm1 = RmsNorm::new(d_model, 1e-8, vb.pp("input_rmsnorm"))?;
|
||||
let norm2 = RmsNorm::new(d_model, 1e-8, vb.pp("post_attention_rmsnorm"))?;
|
||||
(Norm::RmsNorm(norm1), Norm::RmsNorm(norm2))
|
||||
}
|
||||
};
|
||||
let layer_scale_1 = match cfg.layer_scale {
|
||||
None => None,
|
||||
Some(ls) => {
|
||||
let ls = LayerScale::new(d_model, ls, vb.pp("self_attn_layer_scale"))?;
|
||||
Some(ls)
|
||||
}
|
||||
};
|
||||
let layer_scale_2 = match cfg.layer_scale {
|
||||
None => None,
|
||||
Some(ls) => {
|
||||
let ls = LayerScale::new(d_model, ls, vb.pp("mlp_layer_scale"))?;
|
||||
Some(ls)
|
||||
}
|
||||
};
|
||||
let self_attn = StreamingMultiheadAttention::new(rope, cfg, vb.pp("self_attn"))?;
|
||||
let cross_attn = if cfg.cross_attention {
|
||||
let norm_cross = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_cross"))?;
|
||||
let cross_attn = StreamingMultiheadCrossAttention::new(cfg, vb.pp("cross_attention"))?;
|
||||
Some((norm_cross, cross_attn))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
norm1,
|
||||
norm2,
|
||||
layer_scale_1,
|
||||
layer_scale_2,
|
||||
cross_attn,
|
||||
norm_first: cfg.norm_first,
|
||||
span: tracing::span!(tracing::Level::TRACE, "transformer-layer"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
ca_src: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
if !self.norm_first {
|
||||
candle::bail!("only norm_first = true is supported")
|
||||
}
|
||||
let norm1 = xs.apply(&self.norm1)?;
|
||||
let xs = (xs
|
||||
+ self
|
||||
.self_attn
|
||||
.forward(&norm1, mask)?
|
||||
.apply(&self.layer_scale_1.as_ref())?)?;
|
||||
|
||||
let xs = match (&self.cross_attn, ca_src) {
|
||||
(Some((norm_cross, cross_attn)), Some(ca_src)) => {
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(norm_cross)?;
|
||||
(residual + cross_attn.forward(&xs, ca_src, None)?)?
|
||||
}
|
||||
_ => xs,
|
||||
};
|
||||
|
||||
let xs = (&xs
|
||||
+ xs.apply(&self.norm2)?
|
||||
.apply(&self.mlp)?
|
||||
.apply(&self.layer_scale_2.as_ref()))?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn reset_kv_cache(&mut self) {
|
||||
self.self_attn.reset_kv_cache()
|
||||
}
|
||||
|
||||
pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
|
||||
self.self_attn.set_kv_cache(kv_cache)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamingTransformer {
|
||||
layers: Vec<StreamingTransformerLayer>,
|
||||
context: usize,
|
||||
positional_embedding: PositionalEmbedding,
|
||||
max_period: usize,
|
||||
}
|
||||
|
||||
impl StreamingTransformer {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_l = vb.pp("layers");
|
||||
let rope = match cfg.positional_embedding {
|
||||
PositionalEmbedding::Rope => {
|
||||
let rope = RotaryEmbedding::new(
|
||||
cfg.d_model / cfg.num_heads,
|
||||
cfg.max_seq_len,
|
||||
cfg.max_period as f32,
|
||||
vb.device(),
|
||||
)?;
|
||||
Some(Arc::new(rope))
|
||||
}
|
||||
PositionalEmbedding::Sin | PositionalEmbedding::None => None,
|
||||
};
|
||||
let mut layers = Vec::with_capacity(cfg.num_layers);
|
||||
for layer_idx in 0..cfg.num_layers {
|
||||
let layer = StreamingTransformerLayer::new(&rope, cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(Self {
|
||||
layers,
|
||||
context: cfg.context,
|
||||
positional_embedding: cfg.positional_embedding,
|
||||
max_period: cfg.max_period,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.forward_ca(xs, None)
|
||||
}
|
||||
|
||||
pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
|
||||
let (_b, t, c) = xs.dims3()?;
|
||||
// We will extract at most "context" from the kv_cache.
|
||||
// Note that the mask will discard the values that are before context.
|
||||
let pos = self.layers[0]
|
||||
.self_attn
|
||||
.kv_cache
|
||||
.k_cache()
|
||||
.current_seq_len()
|
||||
.min(self.context);
|
||||
let mask = if t == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(t, pos + t, self.context, xs.device())?)
|
||||
};
|
||||
let mut xs = match self.positional_embedding {
|
||||
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
|
||||
PositionalEmbedding::Sin => {
|
||||
let dev = xs.device();
|
||||
let theta = self.max_period as f32;
|
||||
let half_dim = c / 2;
|
||||
let positions = Tensor::arange(pos as u32, (pos + t) as u32, dev)?
|
||||
.unsqueeze(1)?
|
||||
.to_dtype(DType::F32)?;
|
||||
let inv_freq: Vec<_> = (0..half_dim)
|
||||
.map(|i| 1f32 / theta.powf(i as f32 / (half_dim - 1) as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let freqs = positions.broadcast_mul(&inv_freq)?;
|
||||
let pos_emb =
|
||||
Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?.to_dtype(xs.dtype())?;
|
||||
xs.broadcast_add(&pos_emb)?
|
||||
}
|
||||
};
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, ca_src, mask.as_ref())?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
pub fn copy_state(&mut self, from: &Self) -> Result<()> {
|
||||
if self.layers.len() != from.layers.len() {
|
||||
candle::bail!("cannot copy kv-caches as the transformers have different depths")
|
||||
}
|
||||
self.layers
|
||||
.iter_mut()
|
||||
.zip(from.layers.iter())
|
||||
.for_each(|(v, w)| v.set_kv_cache(w.self_attn.kv_cache.clone()));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for StreamingTransformer {
|
||||
fn reset_state(&mut self) {
|
||||
self.layers.iter_mut().for_each(|v| v.reset_kv_cache())
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
match xs.as_option() {
|
||||
None => Ok(StreamTensor::empty()),
|
||||
Some(xs) => Ok(StreamTensor::from_tensor(self.forward(xs)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProjectedTransformer {
|
||||
transformer: StreamingTransformer,
|
||||
input_proj: Option<Linear>,
|
||||
output_projs: Vec<Option<Linear>>,
|
||||
conv_layout: bool,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl ProjectedTransformer {
|
||||
pub fn new(
|
||||
input_dim: usize,
|
||||
output_dims: &[usize],
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let transformer = StreamingTransformer::new(cfg, vb.clone())?;
|
||||
let input_proj = if input_dim == cfg.d_model {
|
||||
None
|
||||
} else {
|
||||
let l = linear_no_bias(input_dim, cfg.d_model, vb.pp("input_proj"))?;
|
||||
Some(l)
|
||||
};
|
||||
let mut output_projs = Vec::with_capacity(output_dims.len());
|
||||
let vb_o = vb.pp("output_projs");
|
||||
for (i, &output_dim) in output_dims.iter().enumerate() {
|
||||
let output_proj = if output_dim == cfg.d_model {
|
||||
None
|
||||
} else {
|
||||
let l = linear_no_bias(cfg.d_model, output_dim, vb_o.pp(i))?;
|
||||
Some(l)
|
||||
};
|
||||
output_projs.push(output_proj)
|
||||
}
|
||||
Ok(Self {
|
||||
transformer,
|
||||
input_proj,
|
||||
output_projs,
|
||||
conv_layout: cfg.conv_layout,
|
||||
span: tracing::span!(tracing::Level::TRACE, "proj-transformer"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Vec<Tensor>> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = if self.conv_layout {
|
||||
xs.transpose(1, 2)?
|
||||
} else {
|
||||
xs.clone()
|
||||
};
|
||||
let xs = xs.apply(&self.input_proj.as_ref())?;
|
||||
let xs = self.transformer.forward(&xs)?;
|
||||
let mut ys = Vec::with_capacity(self.output_projs.len());
|
||||
for output_proj in self.output_projs.iter() {
|
||||
let ys_ = xs.apply(&output_proj.as_ref())?;
|
||||
let ys_ = if self.conv_layout {
|
||||
ys_.transpose(1, 2)?
|
||||
} else {
|
||||
ys_
|
||||
};
|
||||
ys.push(ys_)
|
||||
}
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamingModule for ProjectedTransformer {
|
||||
fn reset_state(&mut self) {
|
||||
self.transformer.reset_state()
|
||||
}
|
||||
|
||||
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
|
||||
let xs = xs.apply(&|x: &Tensor| {
|
||||
if self.conv_layout {
|
||||
x.transpose(1, 2)
|
||||
} else {
|
||||
Ok(x.clone())
|
||||
}
|
||||
})?;
|
||||
let xs = xs.apply(&self.input_proj.as_ref())?;
|
||||
let xs = self.transformer.step(&xs)?;
|
||||
let ys = xs.apply(&self.output_projs[0].as_ref())?;
|
||||
ys.apply(&|y: &Tensor| {
|
||||
if self.conv_layout {
|
||||
y.transpose(1, 2)
|
||||
} else {
|
||||
Ok(y.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
294
candle-transformers/src/models/mmdit/blocks.rs
Normal file
294
candle-transformers/src/models/mmdit/blocks.rs
Normal file
@ -0,0 +1,294 @@
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
|
||||
|
||||
pub struct ModulateIntermediates {
|
||||
gate_msa: Tensor,
|
||||
shift_mlp: Tensor,
|
||||
scale_mlp: Tensor,
|
||||
gate_mlp: Tensor,
|
||||
}
|
||||
|
||||
pub struct DiTBlock {
|
||||
norm1: LayerNormNoAffine,
|
||||
attn: AttnProjections,
|
||||
norm2: LayerNormNoAffine,
|
||||
mlp: Mlp,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
pub struct LayerNormNoAffine {
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl LayerNormNoAffine {
|
||||
pub fn new(eps: f64) -> Self {
|
||||
Self { eps }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNormNoAffine {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl DiTBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
// {'hidden_size': 1536, 'num_heads': 24}
|
||||
let norm1 = LayerNormNoAffine::new(1e-6);
|
||||
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
|
||||
let norm2 = LayerNormNoAffine::new(1e-6);
|
||||
let mlp_ratio = 4;
|
||||
let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
|
||||
let n_mods = 6;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
n_mods * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
mlp,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(6, D::Minus1)?;
|
||||
let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
|
||||
chunks[0].clone(),
|
||||
chunks[1].clone(),
|
||||
chunks[2].clone(),
|
||||
chunks[3].clone(),
|
||||
chunks[4].clone(),
|
||||
chunks[5].clone(),
|
||||
);
|
||||
|
||||
let norm_x = self.norm1.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
|
||||
let qkv = self.attn.pre_attention(&modulated_x)?;
|
||||
|
||||
Ok((
|
||||
qkv,
|
||||
ModulateIntermediates {
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
pub fn post_attention(
|
||||
&self,
|
||||
attn: &Tensor,
|
||||
x: &Tensor,
|
||||
mod_interm: &ModulateIntermediates,
|
||||
) -> Result<Tensor> {
|
||||
let attn_out = self.attn.post_attention(attn)?;
|
||||
let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
|
||||
|
||||
let norm_x = self.norm2.forward(&x)?;
|
||||
let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
|
||||
let mlp_out = self.mlp.forward(&modulated_x)?;
|
||||
let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
|
||||
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QkvOnlyDiTBlock {
|
||||
norm1: LayerNormNoAffine,
|
||||
attn: QkvOnlyAttnProjections,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
impl QkvOnlyDiTBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let norm1 = LayerNormNoAffine::new(1e-6);
|
||||
let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
|
||||
let n_mods = 2;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
n_mods * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(2, D::Minus1)?;
|
||||
let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
|
||||
|
||||
let norm_x = self.norm1.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
|
||||
self.attn.pre_attention(&modulated_x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FinalLayer {
|
||||
norm_final: LayerNormNoAffine,
|
||||
linear: nn::Linear,
|
||||
ada_ln_modulation: nn::Sequential,
|
||||
}
|
||||
|
||||
impl FinalLayer {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm_final = LayerNormNoAffine::new(1e-6);
|
||||
let linear = nn::linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
vb.pp("linear"),
|
||||
)?;
|
||||
let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
|
||||
hidden_size,
|
||||
2 * hidden_size,
|
||||
vb.pp("adaLN_modulation.1"),
|
||||
)?);
|
||||
|
||||
Ok(Self {
|
||||
norm_final,
|
||||
linear,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let modulation = self.ada_ln_modulation.forward(c)?;
|
||||
let chunks = modulation.chunk(2, D::Minus1)?;
|
||||
let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
|
||||
|
||||
let norm_x = self.norm_final.forward(x)?;
|
||||
let modulated_x = modulate(&norm_x, &shift, &scale)?;
|
||||
let output = self.linear.forward(&modulated_x)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
|
||||
let shift = shift.unsqueeze(1)?;
|
||||
let scale = scale.unsqueeze(1)?;
|
||||
let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
|
||||
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
|
||||
}
|
||||
|
||||
pub struct JointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: DiTBlock,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl JointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
let context_out =
|
||||
self.context_block
|
||||
.post_attention(&context_attn, context, &context_interm)?;
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok((context_out, x_out))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextQkvOnlyJointBlock {
|
||||
x_block: DiTBlock,
|
||||
context_block: QkvOnlyDiTBlock,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl ContextQkvOnlyJointBlock {
|
||||
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
|
||||
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
|
||||
Ok(Self {
|
||||
x_block,
|
||||
context_block,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let context_qkv = self.context_block.pre_attention(context, c)?;
|
||||
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
|
||||
|
||||
let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
|
||||
|
||||
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
|
||||
Ok(x_out)
|
||||
}
|
||||
}
|
||||
|
||||
// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn
|
||||
// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim)
|
||||
fn flash_compatible_attention(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
) -> Result<Tensor> {
|
||||
let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
|
||||
let rank = q_dims_for_matmul.len();
|
||||
let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
|
||||
let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
|
||||
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
|
||||
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
|
||||
let qkv = Qkv {
|
||||
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
|
||||
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
|
||||
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
|
||||
};
|
||||
|
||||
let (batch_size, seqlen, _) = qkv.q.dims3()?;
|
||||
let qkv = Qkv {
|
||||
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
|
||||
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
|
||||
v: qkv.v,
|
||||
};
|
||||
|
||||
let headdim = qkv.q.dim(D::Minus1)?;
|
||||
let softmax_scale = 1.0 / (headdim as f64).sqrt();
|
||||
// let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
|
||||
let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
|
||||
|
||||
let attn = attn.reshape((batch_size, seqlen, ()))?;
|
||||
let context_qkv_seqlen = context_qkv.q.dim(1)?;
|
||||
let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
|
||||
let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
|
||||
|
||||
Ok((context_attn, x_attn))
|
||||
}
|
197
candle-transformers/src/models/mmdit/embedding.rs
Normal file
197
candle-transformers/src/models/mmdit/embedding.rs
Normal file
@ -0,0 +1,197 @@
|
||||
use candle::{bail, DType, Module, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
pub struct PatchEmbedder {
|
||||
proj: nn::Conv2d,
|
||||
}
|
||||
|
||||
impl PatchEmbedder {
|
||||
pub fn new(
|
||||
patch_size: usize,
|
||||
in_channels: usize,
|
||||
embed_dim: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let proj = nn::conv2d(
|
||||
in_channels,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("proj"),
|
||||
)?;
|
||||
|
||||
Ok(Self { proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbedder {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.proj.forward(x)?;
|
||||
|
||||
// flatten spatial dim and transpose to channels last
|
||||
let (b, c, h, w) = x.dims4()?;
|
||||
x.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Unpatchifier {
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
}
|
||||
|
||||
impl Unpatchifier {
|
||||
pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
|
||||
Ok(Self {
|
||||
patch_size,
|
||||
out_channels,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
|
||||
let h = (h + 1) / self.patch_size;
|
||||
let w = (w + 1) / self.patch_size;
|
||||
|
||||
let x = x.reshape((
|
||||
x.dim(0)?,
|
||||
h,
|
||||
w,
|
||||
self.patch_size,
|
||||
self.patch_size,
|
||||
self.out_channels,
|
||||
))?;
|
||||
let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
|
||||
x.reshape((
|
||||
x.dim(0)?,
|
||||
self.out_channels,
|
||||
self.patch_size * h,
|
||||
self.patch_size * w,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PositionEmbedder {
|
||||
pos_embed: Tensor,
|
||||
patch_size: usize,
|
||||
pos_embed_max_size: usize,
|
||||
}
|
||||
|
||||
impl PositionEmbedder {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
patch_size: usize,
|
||||
pos_embed_max_size: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let pos_embed = vb.get(
|
||||
(1, pos_embed_max_size * pos_embed_max_size, hidden_size),
|
||||
"pos_embed",
|
||||
)?;
|
||||
Ok(Self {
|
||||
pos_embed,
|
||||
patch_size,
|
||||
pos_embed_max_size,
|
||||
})
|
||||
}
|
||||
pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||
let h = (h + 1) / self.patch_size;
|
||||
let w = (w + 1) / self.patch_size;
|
||||
|
||||
if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
|
||||
bail!("Input size is too large for the position embedding")
|
||||
}
|
||||
|
||||
let top = (self.pos_embed_max_size - h) / 2;
|
||||
let left = (self.pos_embed_max_size - w) / 2;
|
||||
|
||||
let pos_embed =
|
||||
self.pos_embed
|
||||
.reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
|
||||
let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
|
||||
pos_embed.reshape((1, h * w, ()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TimestepEmbedder {
|
||||
mlp: nn::Sequential,
|
||||
frequency_embedding_size: usize,
|
||||
}
|
||||
|
||||
impl TimestepEmbedder {
|
||||
pub fn new(
|
||||
hidden_size: usize,
|
||||
frequency_embedding_size: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mlp = nn::seq()
|
||||
.add(nn::linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
vb.pp("mlp.0"),
|
||||
)?)
|
||||
.add(nn::Activation::Silu)
|
||||
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
|
||||
|
||||
Ok(Self {
|
||||
mlp,
|
||||
frequency_embedding_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
|
||||
if dim % 2 != 0 {
|
||||
bail!("Embedding dimension must be even")
|
||||
}
|
||||
|
||||
if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
|
||||
bail!("Input tensor must be floating point")
|
||||
}
|
||||
|
||||
let half = dim / 2;
|
||||
let freqs = Tensor::arange(0f32, half as f32, t.device())?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.mul(&Tensor::full(
|
||||
(-f64::ln(max_period) / half as f64) as f32,
|
||||
half,
|
||||
t.device(),
|
||||
)?)?
|
||||
.exp()?;
|
||||
|
||||
let args = t
|
||||
.unsqueeze(1)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.matmul(&freqs.unsqueeze(0)?)?;
|
||||
let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
|
||||
embedding.to_dtype(candle::DType::F16)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TimestepEmbedder {
|
||||
fn forward(&self, t: &Tensor) -> Result<Tensor> {
|
||||
let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
|
||||
self.mlp.forward(&t_freq)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorEmbedder {
|
||||
mlp: nn::Sequential,
|
||||
}
|
||||
|
||||
impl VectorEmbedder {
|
||||
pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let mlp = nn::seq()
|
||||
.add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
|
||||
.add(nn::Activation::Silu)
|
||||
.add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
|
||||
|
||||
Ok(Self { mlp })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for VectorEmbedder {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.mlp.forward(x)
|
||||
}
|
||||
}
|
4
candle-transformers/src/models/mmdit/mod.rs
Normal file
4
candle-transformers/src/models/mmdit/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod blocks;
|
||||
pub mod embedding;
|
||||
pub mod model;
|
||||
pub mod projections;
|
173
candle-transformers/src/models/mmdit/model.rs
Normal file
173
candle-transformers/src/models/mmdit/model.rs
Normal file
@ -0,0 +1,173 @@
|
||||
// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
|
||||
// This follows the implementation of the MMDiT model in the ComfyUI repository.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
|
||||
use candle::{Module, Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
|
||||
use super::embedding::{
|
||||
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub patch_size: usize,
|
||||
pub in_channels: usize,
|
||||
pub out_channels: usize,
|
||||
pub depth: usize,
|
||||
pub head_size: usize,
|
||||
pub adm_in_channels: usize,
|
||||
pub pos_embed_max_size: usize,
|
||||
pub context_embed_size: usize,
|
||||
pub frequency_embedding_size: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn sd3() -> Self {
|
||||
Self {
|
||||
patch_size: 2,
|
||||
in_channels: 16,
|
||||
out_channels: 16,
|
||||
depth: 24,
|
||||
head_size: 64,
|
||||
adm_in_channels: 2048,
|
||||
pos_embed_max_size: 192,
|
||||
context_embed_size: 4096,
|
||||
frequency_embedding_size: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MMDiT {
|
||||
core: MMDiTCore,
|
||||
patch_embedder: PatchEmbedder,
|
||||
pos_embedder: PositionEmbedder,
|
||||
timestep_embedder: TimestepEmbedder,
|
||||
vector_embedder: VectorEmbedder,
|
||||
context_embedder: nn::Linear,
|
||||
unpatchifier: Unpatchifier,
|
||||
}
|
||||
|
||||
impl MMDiT {
|
||||
pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.head_size * cfg.depth;
|
||||
let core = MMDiTCore::new(
|
||||
cfg.depth,
|
||||
hidden_size,
|
||||
cfg.depth,
|
||||
cfg.patch_size,
|
||||
cfg.out_channels,
|
||||
vb.clone(),
|
||||
)?;
|
||||
let patch_embedder = PatchEmbedder::new(
|
||||
cfg.patch_size,
|
||||
cfg.in_channels,
|
||||
hidden_size,
|
||||
vb.pp("x_embedder"),
|
||||
)?;
|
||||
let pos_embedder = PositionEmbedder::new(
|
||||
hidden_size,
|
||||
cfg.patch_size,
|
||||
cfg.pos_embed_max_size,
|
||||
vb.clone(),
|
||||
)?;
|
||||
let timestep_embedder = TimestepEmbedder::new(
|
||||
hidden_size,
|
||||
cfg.frequency_embedding_size,
|
||||
vb.pp("t_embedder"),
|
||||
)?;
|
||||
let vector_embedder =
|
||||
VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
|
||||
let context_embedder = nn::linear(
|
||||
cfg.context_embed_size,
|
||||
hidden_size,
|
||||
vb.pp("context_embedder"),
|
||||
)?;
|
||||
let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
|
||||
|
||||
Ok(Self {
|
||||
core,
|
||||
patch_embedder,
|
||||
pos_embedder,
|
||||
timestep_embedder,
|
||||
vector_embedder,
|
||||
context_embedder,
|
||||
unpatchifier,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
|
||||
// Following the convention of the ComfyUI implementation.
|
||||
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
|
||||
//
|
||||
// Forward pass of DiT.
|
||||
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
// t: (N,) tensor of diffusion timesteps
|
||||
// y: (N,) tensor of class labels
|
||||
let h = x.dim(D::Minus2)?;
|
||||
let w = x.dim(D::Minus1)?;
|
||||
let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
|
||||
let x = self
|
||||
.patch_embedder
|
||||
.forward(x)?
|
||||
.broadcast_add(&cropped_pos_embed)?;
|
||||
let c = self.timestep_embedder.forward(t)?;
|
||||
let y = self.vector_embedder.forward(y)?;
|
||||
let c = (c + y)?;
|
||||
let context = self.context_embedder.forward(context)?;
|
||||
|
||||
let x = self.core.forward(&context, &x, &c)?;
|
||||
let x = self.unpatchifier.unpatchify(&x, h, w)?;
|
||||
x.narrow(2, 0, h)?.narrow(3, 0, w)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MMDiTCore {
|
||||
joint_blocks: Vec<JointBlock>,
|
||||
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
|
||||
final_layer: FinalLayer,
|
||||
}
|
||||
|
||||
impl MMDiTCore {
|
||||
pub fn new(
|
||||
depth: usize,
|
||||
hidden_size: usize,
|
||||
num_heads: usize,
|
||||
patch_size: usize,
|
||||
out_channels: usize,
|
||||
vb: nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let mut joint_blocks = Vec::with_capacity(depth - 1);
|
||||
for i in 0..depth - 1 {
|
||||
joint_blocks.push(JointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
vb.pp(format!("joint_blocks.{}", i)),
|
||||
)?);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
joint_blocks,
|
||||
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
vb.pp(format!("joint_blocks.{}", depth - 1)),
|
||||
)?,
|
||||
final_layer: FinalLayer::new(
|
||||
hidden_size,
|
||||
patch_size,
|
||||
out_channels,
|
||||
vb.pp("final_layer"),
|
||||
)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
|
||||
let (mut context, mut x) = (context.clone(), x.clone());
|
||||
for joint_block in &self.joint_blocks {
|
||||
(context, x) = joint_block.forward(&context, &x, c)?;
|
||||
}
|
||||
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
|
||||
self.final_layer.forward(&x, c)
|
||||
}
|
||||
}
|
94
candle-transformers/src/models/mmdit/projections.rs
Normal file
94
candle-transformers/src/models/mmdit/projections.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn as nn;
|
||||
|
||||
pub struct Qkv {
|
||||
pub q: Tensor,
|
||||
pub k: Tensor,
|
||||
pub v: Tensor,
|
||||
}
|
||||
|
||||
pub struct Mlp {
|
||||
fc1: nn::Linear,
|
||||
act: nn::Activation,
|
||||
fc2: nn::Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
pub fn new(
|
||||
in_features: usize,
|
||||
hidden_features: usize,
|
||||
vb: candle_nn::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
|
||||
let act = nn::Activation::GeluPytorchTanh;
|
||||
let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
|
||||
|
||||
Ok(Self { fc1, act, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.fc1.forward(x)?;
|
||||
let x = self.act.forward(&x)?;
|
||||
self.fc2.forward(&x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct QkvOnlyAttnProjections {
|
||||
qkv: nn::Linear,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl QkvOnlyAttnProjections {
|
||||
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
// {'dim': 1536, 'num_heads': 24}
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
|
||||
Ok(Self { qkv, head_dim })
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
split_qkv(&qkv, self.head_dim)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AttnProjections {
|
||||
head_dim: usize,
|
||||
qkv: nn::Linear,
|
||||
proj: nn::Linear,
|
||||
}
|
||||
|
||||
impl AttnProjections {
|
||||
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
|
||||
let proj = nn::linear(dim, dim, vb.pp("proj"))?;
|
||||
Ok(Self {
|
||||
head_dim,
|
||||
qkv,
|
||||
proj,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
|
||||
let qkv = self.qkv.forward(x)?;
|
||||
split_qkv(&qkv, self.head_dim)
|
||||
}
|
||||
|
||||
pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.proj.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
|
||||
let (batch_size, seq_len, _) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
|
||||
let q = qkv.get_on_dim(2, 0)?;
|
||||
let q = q.reshape((batch_size, seq_len, ()))?;
|
||||
let k = qkv.get_on_dim(2, 1)?;
|
||||
let k = k.reshape((batch_size, seq_len, ()))?;
|
||||
let v = qkv.get_on_dim(2, 2)?;
|
||||
Ok(Qkv { q, k, v })
|
||||
}
|
89
candle-transformers/src/models/mobileclip.rs
Normal file
89
candle-transformers/src/models/mobileclip.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use super::fastvit;
|
||||
use super::openclip::text_model;
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{Func, VarBuilder};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MobileClipModel {
|
||||
text_model: text_model::OpenClipTextTransformer,
|
||||
vision_model: Func<'static>,
|
||||
text_projection: Tensor,
|
||||
logit_scale: Tensor,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MobileClipConfig {
|
||||
pub text_config: text_model::Config,
|
||||
pub vision_config: fastvit::Config,
|
||||
pub image_size: usize,
|
||||
}
|
||||
|
||||
impl MobileClipConfig {
|
||||
pub fn s1() -> Self {
|
||||
let text_config = text_model::Config::vit_base_patch32();
|
||||
let vision_config = fastvit::Config::mci1();
|
||||
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
image_size: 256,
|
||||
}
|
||||
}
|
||||
pub fn s2() -> Self {
|
||||
let text_config = text_model::Config::vit_base_patch32();
|
||||
let vision_config = fastvit::Config::mci2();
|
||||
|
||||
Self {
|
||||
text_config,
|
||||
vision_config,
|
||||
image_size: 256,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MobileClipModel {
|
||||
pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
|
||||
let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
|
||||
let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
|
||||
|
||||
let text_projection = vs.get(
|
||||
(c.text_config.embed_dim, c.text_config.projection_dim),
|
||||
"text.text_projection",
|
||||
)?;
|
||||
|
||||
let logit_scale = vs.get(&[], "logit_scale")?;
|
||||
Ok(Self {
|
||||
text_model,
|
||||
vision_model,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
input_ids
|
||||
.apply(&self.text_model)?
|
||||
.matmul(&self.text_projection)
|
||||
}
|
||||
|
||||
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
|
||||
pixel_values.apply(&self.vision_model)
|
||||
}
|
||||
|
||||
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
let image_features = self.get_image_features(pixel_values)?;
|
||||
let text_features = self.get_text_features(input_ids)?;
|
||||
let image_features_normalized = div_l2_norm(&image_features)?;
|
||||
let text_features_normalized = div_l2_norm(&text_features)?;
|
||||
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
|
||||
let logit_scale = self.logit_scale.exp()?;
|
||||
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
|
||||
let logits_per_image = logits_per_text.t()?;
|
||||
Ok((logits_per_text, logits_per_image))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
|
||||
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
|
||||
v.broadcast_div(&l2_norm)
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
pub mod based;
|
||||
pub mod beit;
|
||||
pub mod bert;
|
||||
pub mod bigcode;
|
||||
@ -8,6 +9,7 @@ pub mod clip;
|
||||
pub mod codegeex4_9b;
|
||||
pub mod convmixer;
|
||||
pub mod convnext;
|
||||
pub mod dac;
|
||||
pub mod depth_anything_v2;
|
||||
pub mod dinov2;
|
||||
pub mod dinov2reg4;
|
||||
@ -17,8 +19,12 @@ pub mod efficientvit;
|
||||
pub mod encodec;
|
||||
pub mod eva2;
|
||||
pub mod falcon;
|
||||
pub mod fastvit;
|
||||
pub mod flux;
|
||||
pub mod gemma;
|
||||
pub mod gemma2;
|
||||
pub mod glm4;
|
||||
pub mod granite;
|
||||
pub mod hiera;
|
||||
pub mod jina_bert;
|
||||
pub mod llama;
|
||||
@ -28,14 +34,19 @@ pub mod llava;
|
||||
pub mod mamba;
|
||||
pub mod marian;
|
||||
pub mod metavoice;
|
||||
pub mod mimi;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
pub mod mixtral;
|
||||
pub mod mmdit;
|
||||
pub mod mobileclip;
|
||||
pub mod mobilenetv4;
|
||||
pub mod mobileone;
|
||||
pub mod moondream;
|
||||
pub mod mpt;
|
||||
pub mod olmo;
|
||||
pub mod openclip;
|
||||
pub mod parler_tts;
|
||||
pub mod persimmon;
|
||||
pub mod phi;
|
||||
pub mod phi3;
|
||||
|
@ -167,7 +167,7 @@ impl VisionTransformer {
|
||||
let blocks = (0..cfg.num_blocks)
|
||||
.map(|i| {
|
||||
VitBlock::new(
|
||||
vb.pp(&format!("blocks.{}", i)),
|
||||
vb.pp(format!("blocks.{}", i)),
|
||||
cfg.embed_dim,
|
||||
cfg.num_heads,
|
||||
cfg,
|
||||
|
1
candle-transformers/src/models/openclip/mod.rs
Normal file
1
candle-transformers/src/models/openclip/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod text_model;
|
266
candle-transformers/src/models/openclip/text_model.rs
Normal file
266
candle-transformers/src/models/openclip/text_model.rs
Normal file
@ -0,0 +1,266 @@
|
||||
//! Text encoder as used in most OpenCLIP pretrained models
|
||||
//! https://github.com/mlfoundations/open_clip
|
||||
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
embedding, layer_norm, linear, ops::softmax_last_dim, Embedding, LayerNorm, Linear, Module,
|
||||
VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub embed_dim: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub pad_with: Option<String>,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub projection_dim: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn vit_base_patch32() -> Self {
|
||||
Self {
|
||||
vocab_size: 49408,
|
||||
embed_dim: 512,
|
||||
intermediate_size: 2048,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: None,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 8,
|
||||
projection_dim: 512,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TextEmbeddings {
|
||||
token_embedding: Embedding,
|
||||
position_embedding: Tensor,
|
||||
}
|
||||
|
||||
impl TextEmbeddings {
|
||||
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||
let token_embedding = embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
|
||||
let position_embedding = vs.get(
|
||||
(c.max_position_embeddings, c.embed_dim),
|
||||
"positional_embedding",
|
||||
)?;
|
||||
Ok(TextEmbeddings {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for TextEmbeddings {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let seq_length = input_ids.dim(D::Minus1)?;
|
||||
let inputs_embeds = self.token_embedding.forward(input_ids)?;
|
||||
|
||||
let position_embedding = self.position_embedding.narrow(0, 0, seq_length)?;
|
||||
|
||||
inputs_embeds.broadcast_add(&position_embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Attention {
|
||||
k_proj: candle_nn::Linear,
|
||||
v_proj: candle_nn::Linear,
|
||||
q_proj: candle_nn::Linear,
|
||||
out_proj: Linear,
|
||||
head_dim: usize,
|
||||
scale: f64,
|
||||
num_attention_heads: usize,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
|
||||
let embed_dim = c.embed_dim;
|
||||
let num_attention_heads = c.num_attention_heads;
|
||||
|
||||
let in_proj_weights = vs
|
||||
.get((embed_dim * 3, embed_dim), "in_proj_weight")?
|
||||
.chunk(3, 0)?;
|
||||
let (q_w, k_w, v_w) = (
|
||||
&in_proj_weights[0],
|
||||
&in_proj_weights[1],
|
||||
&in_proj_weights[2],
|
||||
);
|
||||
let in_proj_biases = vs.get(embed_dim * 3, "in_proj_bias")?.chunk(3, 0)?;
|
||||
let (q_b, k_b, v_b) = (&in_proj_biases[0], &in_proj_biases[1], &in_proj_biases[2]);
|
||||
|
||||
let q_proj = Linear::new(q_w.clone(), Some(q_b.clone()));
|
||||
let k_proj = Linear::new(k_w.clone(), Some(k_b.clone()));
|
||||
let v_proj = Linear::new(v_w.clone(), Some(v_b.clone()));
|
||||
let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
|
||||
let head_dim = embed_dim / num_attention_heads;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
|
||||
Ok(Attention {
|
||||
k_proj,
|
||||
v_proj,
|
||||
q_proj,
|
||||
out_proj,
|
||||
head_dim,
|
||||
scale,
|
||||
num_attention_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn shape_multihead(&self, xs: &Tensor, bsz: usize, seq_len: usize) -> Result<Tensor> {
|
||||
xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?
|
||||
.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let in_dtype = xs.dtype();
|
||||
let (bsz, seq_len, embed_dim) = xs.dims3()?;
|
||||
|
||||
let q = self.shape_multihead(&self.q_proj.forward(xs)?, bsz, seq_len)?;
|
||||
let k = self.shape_multihead(&self.k_proj.forward(xs)?, bsz, seq_len)?;
|
||||
let v = self.shape_multihead(&self.v_proj.forward(xs)?, bsz, seq_len)?;
|
||||
let q = (q * self.scale)?;
|
||||
|
||||
let attn_weights = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
|
||||
|
||||
let attn_weights = softmax_last_dim(&attn_weights)?;
|
||||
|
||||
let attn_output = attn_weights.matmul(&v)?.to_dtype(in_dtype)?;
|
||||
let attn_output = attn_output
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?
|
||||
.reshape((bsz, seq_len, embed_dim))?;
|
||||
let out = self.out_proj.forward(&attn_output)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||
let fc1 = linear(c.embed_dim, c.intermediate_size, vs.pp("c_fc"))?;
|
||||
let fc2 = linear(c.intermediate_size, c.embed_dim, vs.pp("c_proj"))?;
|
||||
|
||||
Ok(Mlp { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?;
|
||||
self.fc2.forward(&xs.gelu_erf()?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct EncoderLayer {
|
||||
self_attn: Attention,
|
||||
layer_norm1: LayerNorm,
|
||||
mlp: Mlp,
|
||||
layer_norm2: LayerNorm,
|
||||
}
|
||||
|
||||
impl EncoderLayer {
|
||||
fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||
let self_attn = Attention::new(vs.pp("attn"), c)?;
|
||||
let layer_norm1 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_1"))?;
|
||||
let mlp = Mlp::new(vs.pp("mlp"), c)?;
|
||||
let layer_norm2 = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_2"))?;
|
||||
|
||||
Ok(EncoderLayer {
|
||||
self_attn,
|
||||
layer_norm1,
|
||||
mlp,
|
||||
layer_norm2,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.layer_norm1.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs)?;
|
||||
let xs = (xs + residual)?;
|
||||
|
||||
let residual = &xs;
|
||||
let xs = self.layer_norm2.forward(&xs)?;
|
||||
let xs = self.mlp.forward(&xs)?;
|
||||
let out = (xs + residual)?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Encoder {
|
||||
layers: Vec<EncoderLayer>,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||
let vs = vs.pp("resblocks");
|
||||
let mut layers: Vec<EncoderLayer> = Vec::new();
|
||||
for index in 0..c.num_hidden_layers {
|
||||
let layer = EncoderLayer::new(vs.pp(index.to_string()), c)?;
|
||||
layers.push(layer)
|
||||
}
|
||||
Ok(Encoder { layers })
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
/// A text transformer as used in CLIP variants.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenClipTextTransformer {
|
||||
embeddings: TextEmbeddings,
|
||||
encoder: Encoder,
|
||||
final_layer_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl OpenClipTextTransformer {
|
||||
pub fn new(vs: VarBuilder, c: &Config) -> Result<Self> {
|
||||
let embeddings = TextEmbeddings::new(vs.clone(), c)?;
|
||||
let final_layer_norm = layer_norm(c.embed_dim, 1e-5, vs.pp("ln_final"))?;
|
||||
let encoder = Encoder::new(vs.pp("transformer"), c)?;
|
||||
Ok(OpenClipTextTransformer {
|
||||
embeddings,
|
||||
encoder,
|
||||
final_layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_ids = self.embeddings.forward(input_ids)?;
|
||||
let input_ids = self.encoder.forward(&input_ids)?;
|
||||
self.final_layer_norm.forward(&input_ids)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for OpenClipTextTransformer {
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let output = self.forward(input_ids)?;
|
||||
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
|
||||
|
||||
let mut indices = Vec::new();
|
||||
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
|
||||
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
|
||||
indices.push(index);
|
||||
}
|
||||
Tensor::cat(&indices, 0)
|
||||
}
|
||||
}
|
456
candle-transformers/src/models/parler_tts.rs
Normal file
456
candle-transformers/src/models/parler_tts.rs
Normal file
@ -0,0 +1,456 @@
|
||||
use crate::generation::LogitsProcessor;
|
||||
use crate::models::t5;
|
||||
use candle::{IndexOp, Result, Tensor};
|
||||
use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct DecoderConfig {
|
||||
pub vocab_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub ffn_dim: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub num_cross_attention_key_value_heads: Option<usize>,
|
||||
pub activation_function: Activation,
|
||||
pub hidden_size: usize,
|
||||
pub scale_embedding: bool,
|
||||
pub num_codebooks: usize,
|
||||
pub pad_token_id: usize,
|
||||
pub bos_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
pub tie_word_embeddings: bool,
|
||||
pub rope_embeddings: bool,
|
||||
pub rope_theta: f64,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub decoder_start_token_id: u32,
|
||||
pub pad_token_id: u32,
|
||||
pub decoder: DecoderConfig,
|
||||
pub text_encoder: t5::Config,
|
||||
pub vocab_size: usize,
|
||||
pub audio_encoder: crate::models::dac::Config,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Attention {
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
q_proj: Linear,
|
||||
out_proj: Linear,
|
||||
is_causal: bool,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
scaling: f64,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
num_kv_heads: usize,
|
||||
is_causal: bool,
|
||||
cfg: &DecoderConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if cfg.rope_embeddings {
|
||||
candle::bail!("rope embeddings are not supported");
|
||||
}
|
||||
let embed_dim = cfg.hidden_size;
|
||||
let head_dim = embed_dim / cfg.num_attention_heads;
|
||||
let kv_out_dim = num_kv_heads * head_dim;
|
||||
let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
|
||||
let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
|
||||
let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
|
||||
Ok(Self {
|
||||
k_proj,
|
||||
v_proj,
|
||||
q_proj,
|
||||
out_proj,
|
||||
is_causal,
|
||||
kv_cache: None,
|
||||
scaling: (head_dim as f64).powf(-0.5),
|
||||
num_heads: cfg.num_attention_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups: cfg.num_attention_heads / num_kv_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
key_value_states: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
|
||||
.reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let key_states = match key_value_states {
|
||||
Some(states) => states.apply(&self.k_proj)?,
|
||||
None => xs.apply(&self.k_proj)?,
|
||||
};
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let value_states = match key_value_states {
|
||||
Some(states) => states.apply(&self.v_proj)?,
|
||||
None => xs.apply(&self.v_proj)?,
|
||||
};
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
if self.is_causal {
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
}
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
let attn_output = attn_weights.matmul(&value_states)?;
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, tgt_len, ()))?
|
||||
.apply(&self.out_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
self_attn_layer_norm: LayerNorm,
|
||||
encoder_attn: Attention,
|
||||
encoder_attn_layer_norm: LayerNorm,
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
final_layer_norm: LayerNorm,
|
||||
activation: Activation,
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
|
||||
let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
|
||||
|
||||
let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
|
||||
let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
|
||||
let self_attn_layer_norm =
|
||||
layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||
let encoder_attn_layer_norm =
|
||||
layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||
let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
|
||||
let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
self_attn_layer_norm,
|
||||
encoder_attn,
|
||||
encoder_attn_layer_norm,
|
||||
fc1,
|
||||
fc2,
|
||||
final_layer_norm,
|
||||
activation: cfg.activation_function,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_xs: &Tensor,
|
||||
encoder_attention_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
// Self attention
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.self_attn_layer_norm)?;
|
||||
let xs = self.self_attn.forward(&xs, None, attention_mask)?;
|
||||
let xs = (residual + xs)?;
|
||||
|
||||
// Cross attention
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.encoder_attn_layer_norm)?;
|
||||
let xs = self
|
||||
.encoder_attn
|
||||
.forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
|
||||
let xs = (residual + xs)?;
|
||||
|
||||
// Fully connected
|
||||
let residual = &xs;
|
||||
let xs = xs
|
||||
.apply(&self.final_layer_norm)?
|
||||
.apply(&self.fc1)?
|
||||
.apply(&self.activation)?
|
||||
.apply(&self.fc2)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache();
|
||||
self.encoder_attn.clear_kv_cache();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Decoder {
|
||||
embed_tokens: Vec<candle_nn::Embedding>,
|
||||
embed_positions: Tensor,
|
||||
layers: Vec<DecoderLayer>,
|
||||
layer_norm: LayerNorm,
|
||||
num_codebooks: usize,
|
||||
hidden_size: usize,
|
||||
lm_heads: Vec<Linear>,
|
||||
dtype: candle::DType,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_d = vb.pp("model.decoder");
|
||||
let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
|
||||
let vb_e = vb_d.pp("embed_tokens");
|
||||
for embed_idx in 0..cfg.num_codebooks {
|
||||
let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
|
||||
embed_tokens.push(e)
|
||||
}
|
||||
let embed_positions = vb_d.get(
|
||||
(cfg.max_position_embeddings, cfg.hidden_size),
|
||||
"embed_positions.weights",
|
||||
)?;
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_d.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
|
||||
|
||||
let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
|
||||
let vb_l = vb.pp("lm_heads");
|
||||
for lm_idx in 0..cfg.num_codebooks {
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
|
||||
lm_heads.push(lm_head)
|
||||
}
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
embed_positions,
|
||||
layers,
|
||||
layer_norm,
|
||||
num_codebooks: cfg.num_codebooks,
|
||||
lm_heads,
|
||||
hidden_size: cfg.hidden_size,
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
prompt_hidden_states: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_xs: &Tensor,
|
||||
encoder_attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
|
||||
if num_codebooks != self.num_codebooks {
|
||||
candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
|
||||
}
|
||||
let mut inputs_embeds = Tensor::zeros(
|
||||
(b_sz, seq_len, self.hidden_size),
|
||||
self.dtype,
|
||||
input_ids.device(),
|
||||
)?;
|
||||
for (idx, embs) in self.embed_tokens.iter().enumerate() {
|
||||
let e = input_ids.i((.., idx))?.apply(embs)?;
|
||||
inputs_embeds = (inputs_embeds + e)?
|
||||
}
|
||||
let inputs_embeds = match prompt_hidden_states {
|
||||
None => inputs_embeds,
|
||||
Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
|
||||
};
|
||||
let embed_positions = self
|
||||
.embed_positions
|
||||
.i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
|
||||
let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
|
||||
}
|
||||
let xs = xs.apply(&self.layer_norm)?;
|
||||
let mut lm_logits = Vec::with_capacity(self.num_codebooks);
|
||||
for lm_head in self.lm_heads.iter() {
|
||||
let logits = xs.apply(lm_head)?;
|
||||
lm_logits.push(logits)
|
||||
}
|
||||
Ok(lm_logits)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
pub embed_prompts: candle_nn::Embedding,
|
||||
pub enc_to_dec_proj: Option<Linear>,
|
||||
pub decoder: Decoder,
|
||||
pub text_encoder: t5::T5EncoderModel,
|
||||
pub decoder_start_token_id: u32,
|
||||
pub pad_token_id: u32,
|
||||
pub audio_encoder: crate::models::dac::Model,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
|
||||
let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
|
||||
let embed_prompts = candle_nn::embedding(
|
||||
cfg.vocab_size,
|
||||
cfg.decoder.hidden_size,
|
||||
vb.pp("embed_prompts"),
|
||||
)?;
|
||||
let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
|
||||
let proj = linear(
|
||||
cfg.text_encoder.d_model,
|
||||
cfg.decoder.hidden_size,
|
||||
true,
|
||||
vb.pp("enc_to_dec_proj"),
|
||||
)?;
|
||||
Some(proj)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let audio_encoder =
|
||||
crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?;
|
||||
Ok(Self {
|
||||
decoder,
|
||||
text_encoder,
|
||||
embed_prompts,
|
||||
enc_to_dec_proj,
|
||||
decoder_start_token_id: cfg.decoder_start_token_id,
|
||||
pad_token_id: cfg.pad_token_id,
|
||||
audio_encoder,
|
||||
})
|
||||
}
|
||||
|
||||
/// Note that the returned tensor uses the CPU device.
|
||||
pub fn generate(
|
||||
&mut self,
|
||||
prompt_tokens: &Tensor,
|
||||
description_tokens: &Tensor,
|
||||
mut lp: LogitsProcessor,
|
||||
max_steps: usize,
|
||||
) -> Result<Tensor> {
|
||||
self.decoder.clear_kv_cache();
|
||||
self.text_encoder.clear_kv_cache();
|
||||
let encoded = self.text_encoder.forward(description_tokens)?;
|
||||
let encoded = match self.enc_to_dec_proj.as_ref() {
|
||||
None => encoded,
|
||||
Some(proj) => encoded.apply(proj)?,
|
||||
};
|
||||
let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
|
||||
let num_codebooks = self.decoder.num_codebooks;
|
||||
let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
|
||||
let mut all_audio_tokens = vec![vec![]; num_codebooks];
|
||||
let prompt_len = prompt_hidden_states.dim(1)?;
|
||||
for step in 0..max_steps {
|
||||
let input_ids = Tensor::from_slice(
|
||||
audio_tokens.as_slice(),
|
||||
(1, num_codebooks, 1),
|
||||
prompt_tokens.device(),
|
||||
)?;
|
||||
let (prompt_hidden_states, pos) = if step == 0 {
|
||||
(Some(&prompt_hidden_states), 0)
|
||||
} else {
|
||||
(None, step + prompt_len)
|
||||
};
|
||||
let causal_mask = if pos == 0 {
|
||||
self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
|
||||
} else {
|
||||
self.prepare_causal_mask(1, pos + 1, input_ids.device())?
|
||||
};
|
||||
let logits = self.decoder.forward(
|
||||
&input_ids,
|
||||
prompt_hidden_states,
|
||||
Some(&causal_mask),
|
||||
&encoded,
|
||||
None,
|
||||
pos,
|
||||
)?;
|
||||
for (logit_idx, logit) in logits.iter().enumerate() {
|
||||
if logit_idx > step {
|
||||
break;
|
||||
}
|
||||
if audio_tokens[logit_idx] != self.pad_token_id {
|
||||
let logit = logit.i((0, logit.dim(1)? - 1))?;
|
||||
let token = lp.sample(&logit)?;
|
||||
audio_tokens[logit_idx] = token
|
||||
}
|
||||
}
|
||||
if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
|
||||
break;
|
||||
}
|
||||
for (cb_idx, &token) in audio_tokens.iter().enumerate() {
|
||||
if token != self.decoder_start_token_id && token != self.pad_token_id {
|
||||
all_audio_tokens[cb_idx].push(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
|
||||
all_audio_tokens.iter_mut().for_each(|v| {
|
||||
v.resize(min_len, 0);
|
||||
});
|
||||
let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
|
||||
Ok(all_audio_tokens)
|
||||
}
|
||||
|
||||
fn prepare_causal_mask(
|
||||
&self,
|
||||
q_len: usize,
|
||||
kv_len: usize,
|
||||
device: &candle::Device,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..q_len)
|
||||
.flat_map(|i| {
|
||||
(0..kv_len).map(move |j| {
|
||||
if i + kv_len < j + q_len {
|
||||
f32::NEG_INFINITY
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (q_len, kv_len), device)
|
||||
}
|
||||
}
|
@ -361,7 +361,7 @@ pub struct ModelForCausalLM {
|
||||
impl ModelForCausalLM {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let base_model = Model::new(cfg, vb.clone())?;
|
||||
let lm_head = if vb.contains_tensor("lm_head") {
|
||||
let lm_head = if vb.contains_tensor("lm_head.weight") {
|
||||
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||
} else {
|
||||
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
||||
|
@ -404,7 +404,7 @@ impl SegformerEncoder {
|
||||
stride,
|
||||
num_channels,
|
||||
hidden_size,
|
||||
vb.pp(&format!("patch_embeddings.{}", i)),
|
||||
vb.pp(format!("patch_embeddings.{}", i)),
|
||||
)?);
|
||||
let mut layers = Vec::with_capacity(config.depths[i]);
|
||||
for j in 0..config.depths[i] {
|
||||
@ -417,14 +417,14 @@ impl SegformerEncoder {
|
||||
num_attention_heads,
|
||||
sequence_reduction_ratio,
|
||||
mlp_ratio,
|
||||
vb.pp(&format!("block.{}.{}", i, j)),
|
||||
vb.pp(format!("block.{}.{}", i, j)),
|
||||
)?);
|
||||
}
|
||||
blocks.push(layers);
|
||||
layer_norms.push(layer_norm(
|
||||
hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp(&format!("layer_norm.{}", i)),
|
||||
vb.pp(format!("layer_norm.{}", i)),
|
||||
)?);
|
||||
}
|
||||
Ok(Self {
|
||||
@ -507,7 +507,7 @@ impl SegformerDecodeHead {
|
||||
linear_c.push(SegformerMLP::new(
|
||||
config,
|
||||
hidden_size,
|
||||
vb.pp(&format!("linear_c.{}", i)),
|
||||
vb.pp(format!("linear_c.{}", i)),
|
||||
)?);
|
||||
}
|
||||
let linear_fuse = conv2d_no_bias(
|
||||
|
@ -659,7 +659,7 @@ struct T5Stack {
|
||||
impl T5Stack {
|
||||
fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
|
||||
let block = (0..cfg.num_layers)
|
||||
.map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
|
||||
.map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let final_layer_norm = T5LayerNorm::load(
|
||||
cfg.d_model,
|
||||
|
@ -260,7 +260,7 @@ impl AudioEncoder {
|
||||
let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
|
||||
let blocks = (0..cfg.encoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
@ -321,7 +321,7 @@ impl TextDecoder {
|
||||
let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
|
||||
let blocks = (0..cfg.decoder_layers)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
|
||||
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
|
||||
|
@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
|
||||
bboxes_for_class.truncate(current_index);
|
||||
}
|
||||
}
|
||||
|
||||
// Updates confidences starting at highest and comparing subsequent boxes.
|
||||
fn update_confidences<D>(
|
||||
bboxes_for_class: &[Bbox<D>],
|
||||
updated_confidences: &mut [f32],
|
||||
iou_threshold: f32,
|
||||
sigma: f32,
|
||||
) {
|
||||
let len = bboxes_for_class.len();
|
||||
for current_index in 0..len {
|
||||
let current_bbox = &bboxes_for_class[current_index];
|
||||
for index in (current_index + 1)..len {
|
||||
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
|
||||
if iou_val > iou_threshold {
|
||||
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
|
||||
let decay = (-iou_val * iou_val / sigma).exp();
|
||||
let updated_confidence = bboxes_for_class[index].confidence * decay;
|
||||
updated_confidences[index] = updated_confidence;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
|
||||
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
|
||||
pub fn soft_non_maximum_suppression<D>(
|
||||
bboxes: &mut [Vec<Bbox<D>>],
|
||||
iou_threshold: Option<f32>,
|
||||
confidence_threshold: Option<f32>,
|
||||
sigma: Option<f32>,
|
||||
) {
|
||||
let iou_threshold = iou_threshold.unwrap_or(0.5);
|
||||
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
|
||||
let sigma = sigma.unwrap_or(0.5);
|
||||
|
||||
for bboxes_for_class in bboxes.iter_mut() {
|
||||
// Sort boxes by confidence in descending order
|
||||
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
|
||||
let mut updated_confidences = bboxes_for_class
|
||||
.iter()
|
||||
.map(|bbox| bbox.confidence)
|
||||
.collect::<Vec<_>>();
|
||||
update_confidences(
|
||||
bboxes_for_class,
|
||||
&mut updated_confidences,
|
||||
iou_threshold,
|
||||
sigma,
|
||||
);
|
||||
// Update confidences, set to 0.0 if below threshold
|
||||
for (i, &confidence) in updated_confidences.iter().enumerate() {
|
||||
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
|
||||
0.0
|
||||
} else {
|
||||
confidence
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
222
candle-transformers/tests/nms_tests.rs
Normal file
222
candle-transformers/tests/nms_tests.rs
Normal file
@ -0,0 +1,222 @@
|
||||
use candle::Result;
|
||||
use candle_transformers::object_detection::{
|
||||
non_maximum_suppression, soft_non_maximum_suppression, Bbox,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn nms_basic() -> Result<()> {
|
||||
// Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 245.0,
|
||||
ymin: 305.0,
|
||||
xmax: 575.0,
|
||||
ymax: 490.0,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
}, // Box 1
|
||||
Bbox {
|
||||
xmin: 235.0,
|
||||
ymin: 300.0,
|
||||
xmax: 485.0,
|
||||
ymax: 515.0,
|
||||
confidence: 0.8,
|
||||
data: (),
|
||||
}, // Box 2
|
||||
Bbox {
|
||||
xmin: 305.0,
|
||||
ymin: 270.0,
|
||||
xmax: 540.0,
|
||||
ymax: 500.0,
|
||||
confidence: 0.6,
|
||||
data: (),
|
||||
}, // Box 3
|
||||
]];
|
||||
|
||||
non_maximum_suppression(&mut bboxes, 0.5);
|
||||
let bboxes = bboxes.into_iter().next().unwrap();
|
||||
assert_eq!(bboxes.len(), 1);
|
||||
assert_eq!(bboxes[0].confidence, 0.9);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_basic_functionality() -> Result<()> {
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.5,
|
||||
data: (),
|
||||
},
|
||||
Bbox {
|
||||
xmin: 0.1,
|
||||
ymin: 0.1,
|
||||
xmax: 1.1,
|
||||
ymax: 1.1,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
},
|
||||
Bbox {
|
||||
xmin: 0.2,
|
||||
ymin: 0.2,
|
||||
xmax: 1.2,
|
||||
ymax: 1.2,
|
||||
confidence: 0.6,
|
||||
data: (),
|
||||
},
|
||||
]];
|
||||
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
|
||||
// Should decay boxes following highest confidence box
|
||||
assert!(bboxes[0][0].confidence == 0.9);
|
||||
assert!(bboxes[0][1].confidence < 0.5);
|
||||
assert!(bboxes[0][2].confidence < 0.6);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_confidence_decay() -> Result<()> {
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
}, // Reference box
|
||||
Bbox {
|
||||
xmin: 0.1,
|
||||
ymin: 0.1,
|
||||
xmax: 1.1,
|
||||
ymax: 1.1,
|
||||
confidence: 0.8,
|
||||
data: (),
|
||||
}, // Overlapping box
|
||||
]];
|
||||
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
|
||||
// Check that confidence of the overlapping box is decayed
|
||||
assert!(bboxes[0][0].confidence == 0.9);
|
||||
assert!(bboxes[0][1].confidence < 0.8);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_confidence_threshold() -> Result<()> {
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
},
|
||||
Bbox {
|
||||
xmin: 0.1,
|
||||
ymin: 0.1,
|
||||
xmax: 1.1,
|
||||
ymax: 1.1,
|
||||
confidence: 0.05,
|
||||
data: (),
|
||||
},
|
||||
]];
|
||||
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
|
||||
// Box with confidence below the threshold should be removed
|
||||
assert_eq!(bboxes[0].len(), 2);
|
||||
assert_eq!(bboxes[0][0].confidence, 0.9);
|
||||
assert_eq!(bboxes[0][1].confidence, 0.00);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_no_overlap() -> Result<()> {
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
},
|
||||
Bbox {
|
||||
xmin: 2.0,
|
||||
ymin: 2.0,
|
||||
xmax: 3.0,
|
||||
ymax: 3.0,
|
||||
confidence: 0.8,
|
||||
data: (),
|
||||
},
|
||||
]];
|
||||
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
|
||||
// Both boxes should remain as they do not significantly overlap
|
||||
assert_eq!(bboxes[0].len(), 2);
|
||||
assert_eq!(bboxes[0][0].confidence, 0.9);
|
||||
assert_eq!(bboxes[0][1].confidence, 0.8);
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn softnms_no_bbox() -> Result<()> {
|
||||
let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
assert!(bboxes.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_single_bbox() -> Result<()> {
|
||||
let mut bboxes = vec![vec![Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.9,
|
||||
data: (),
|
||||
}]];
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
assert_eq!(bboxes[0].len(), 1);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softnms_equal_confidence_overlap() -> Result<()> {
|
||||
let mut bboxes = vec![vec![
|
||||
Bbox {
|
||||
xmin: 0.0,
|
||||
ymin: 0.0,
|
||||
xmax: 1.0,
|
||||
ymax: 1.0,
|
||||
confidence: 0.5,
|
||||
data: (),
|
||||
},
|
||||
Bbox {
|
||||
xmin: 0.1,
|
||||
ymin: 0.1,
|
||||
xmax: 1.1,
|
||||
ymax: 1.1,
|
||||
confidence: 0.5,
|
||||
data: (),
|
||||
},
|
||||
]];
|
||||
|
||||
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
|
||||
|
||||
// First box will be reference box, second box should be decayed
|
||||
// Implementation must change to have both be decayed
|
||||
assert_eq!(bboxes[0].len(), 2);
|
||||
assert!(bboxes[0][0].confidence == 0.5);
|
||||
assert!(bboxes[0][1].confidence < 0.5);
|
||||
Ok(())
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user