Compare commits

..

10 Commits

182 changed files with 1684 additions and 23805 deletions

View File

@ -9,8 +9,7 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g4dn-2xlarge
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0

View File

@ -18,9 +18,9 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
@ -65,4 +65,4 @@ jobs:
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests
python -m pytest -s -v tests

View File

@ -1,6 +1,6 @@
on:
on:
push:
branches:
branches:
- main
pull_request:
@ -15,7 +15,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -34,7 +34,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -49,7 +49,7 @@ jobs:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -65,7 +65,7 @@ jobs:
name: Clippy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal

6
.gitignore vendored
View File

@ -40,9 +40,3 @@ candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
__pycache__
out.safetensors
out.wav
bria.mp3
bria.safetensors
bria.wav

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.7.2"
version = "0.6.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
candle-nn = { path = "./candle-nn", version = "0.7.2" }
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
candle-nn = { path = "./candle-nn", version = "0.6.0" }
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "=0.11.6", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
@ -70,9 +70,6 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
ug-metal = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -2,8 +2,7 @@
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
![License](https://img.shields.io/crates/l/candle-core.svg)
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
@ -64,9 +63,7 @@ We also provide a some command line based examples using state of the art models
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma v1 and v2](./candle-examples/examples/gemma/): 2b and 7b+/9b general LLMs from Google Deepmind.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
@ -121,8 +118,6 @@ We also provide a some command line based examples using state of the art models
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [Parler-TTS](./candle-examples/examples/parler-tts/): large text-to-speech
model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -188,7 +183,6 @@ And then head over to
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
If you have an addition to this list, please submit a pull request.
@ -212,7 +206,7 @@ If you have an addition to this list, please submit a pull request.
- StarCoder, StarCoder2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba
- Gemma v1 2b and 7b+, v2 2b and 9b.
- Gemma 2b and 7b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@ -240,10 +234,9 @@ If you have an addition to this list, please submit a pull request.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.

View File

@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());

View File

@ -28,9 +28,6 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
@ -42,11 +39,11 @@ criterion = { workspace = true }
[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"

View File

@ -1,6 +1,6 @@
use crate::WithDType;
use cudarc;
use cudarc::cudnn::safe::{ConvForward, Cudnn};
use cudarc::cudnn::safe::{Conv2dForward, Cudnn};
use cudarc::driver::{CudaSlice, CudaView, DeviceRepr, ValidAsZeroBits};
use std::cell::RefCell;
use std::collections::HashMap;
@ -26,7 +26,6 @@ impl From<cudarc::driver::DriverError> for crate::Error {
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
@ -49,7 +48,7 @@ pub(crate) fn launch_conv2d<
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
@ -63,18 +62,18 @@ pub(crate) fn launch_conv2d<
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(
cudnn.create_4d_tensor_ex(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
let w = cudnn.create_4d_filter::<T>(
let w = cudnn.create_4d_filter(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
@ -84,11 +83,11 @@ pub(crate) fn launch_conv2d<
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor::<T>(
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;
let conv2d = ConvForward {
let conv2d = Conv2dForward {
conv: &conv,
x: &x,
w: &w,

View File

@ -51,27 +51,6 @@ impl CudaDevice {
self.device.clone()
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -165,20 +144,6 @@ impl CudaDevice {
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;

View File

@ -174,7 +174,6 @@ impl Map1 for Im2Col1D {
}
}
#[allow(unused)]
struct Im2Col {
h_k: usize,
w_k: usize,
@ -184,7 +183,6 @@ struct Im2Col {
}
impl Im2Col {
#[allow(unused)]
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
@ -1522,7 +1520,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
@ -1530,10 +1528,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
@ -1541,7 +1536,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
@ -1549,7 +1544,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
@ -1557,7 +1552,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}

View File

@ -375,110 +375,3 @@ impl Tensor {
)
}
}
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
}
impl UgIOp1 {
#[allow(unused)]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(feature = "metal")]
{
let device = device.as_metal_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
}
}
impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
#[cfg(feature = "metal")]
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
use crate::backend::BackendStorage;
use candle_metal_kernels::utils::EncoderProvider;
let elem_count = layout.shape().elem_count();
if sto.dtype() != crate::DType::F32 {
// TODO: support more dtypes.
crate::bail!("input is not a f32 tensor")
}
let device = sto.device();
println!("here");
let command_buffer = device.command_buffer()?;
let command_buffer = &command_buffer;
let encoder = command_buffer.encoder();
let encoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&self.func);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let grid_dims = metal::MTLSize {
width: g as u64,
height: 1,
depth: 1,
};
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
Ok(())
}
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
let elem_count = layout.shape().elem_count();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}

View File

@ -130,26 +130,6 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
match self {
Self::Cuda(d) => Ok(d),
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
}
}
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
match self {
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
Self::Metal(d) => Ok(d),
}
}
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}

View File

@ -14,12 +14,6 @@ macro_rules! fail {
};
}
impl CudaDevice {
pub fn new_with_stream(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;

View File

@ -165,9 +165,6 @@ pub enum Error {
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
Ug(#[from] ug::Error),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
@ -182,10 +179,6 @@ pub enum Error {
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// Utf8 parse error.
#[error(transparent)]
FromUtf8(#[from] std::string::FromUtf8Error),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),

View File

@ -141,117 +141,28 @@ impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0., 1.],
/// [2., 3.],
/// [4., 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i(0)?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
///
/// let c = a.i(..2)?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f64>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i(1..)?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f64>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0f32, 1.],
/// [2. , 3.],
/// [4. , 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i((0,))?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
///
/// let c = a.i((..2,))?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i((1..,))?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f32>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
self.index(&[a.into()])
}
}
#[allow(non_snake_case)]
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
///
/// let b = a.i((1, 0))?;
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
///
/// let c = a.i((..2, 1))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
self.index(&[a.into(), b.into()])
}
}
macro_rules! index_op_tuple {
($doc:tt, $($t:ident),+) => {
($($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
#[doc=$doc]
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);

View File

@ -65,7 +65,6 @@ pub mod scalar;
pub mod shape;
mod sort;
mod storage;
pub mod streaming;
mod strided_index;
mod tensor;
mod tensor_cat;
@ -77,15 +76,14 @@ mod variable;
pub use cuda_backend::cudnn;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use indexer::IndexOp;
pub use layout::Layout;
pub use shape::{Shape, D};
pub use storage::Storage;
pub use streaming::{StreamTensor, StreamingBinOp, StreamingModule};
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use variable::Var;

View File

@ -4,7 +4,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock};
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};
use super::MetalError;
@ -22,73 +22,7 @@ impl DeviceId {
}
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
pub(crate) struct Commands {
/// Single command queue for the entire device.
command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
command_buffer: CommandBuffer,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
command_buffer_index: usize,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
compute_per_buffer: usize,
}
impl Commands {
pub(crate) fn new(command_queue: CommandQueue) -> Result<Self> {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
Ok(Self {
command_queue,
command_buffer,
command_buffer_index: 0,
compute_per_buffer,
})
}
pub fn command_buffer(&mut self) -> Result<(bool, CommandBuffer)> {
let mut command_buffer = self.command_buffer.to_owned();
let mut flushed = false;
if self.command_buffer_index > self.compute_per_buffer {
self.command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
self.command_buffer = command_buffer.clone();
self.command_buffer_index = 0;
flushed = true;
}
self.command_buffer_index += 1;
Ok((flushed, command_buffer))
}
pub fn wait_until_completed(&mut self) -> Result<()> {
match self.command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
self.command_buffer.commit();
self.command_buffer.wait_until_completed();
self.command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
}
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
#[derive(Clone)]
pub struct MetalDevice {
@ -99,8 +33,27 @@ pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
pub(crate) device: metal::Device,
pub(crate) commands: Arc<RwLock<Commands>>,
/// Single command queue for the entire device.
pub(crate) command_queue: CommandQueue,
/// One command buffer at a time.
/// The scheduler works by allowing multiple
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
/// on a single command buffer. Using a single command buffer would be fastest on the GPU but
/// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
/// to start to work).
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
/// for their START time, but there's no guarantee that command buffer1 will finish before
/// command buffer2 starts (or there are metal bugs there)
pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
/// Keeps track of the current amount of compute command encoders on the current
/// command buffer
/// Arc, RwLock because of the interior mutability.
pub(crate) command_buffer_index: Arc<RwLock<usize>>,
/// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
pub(crate) compute_per_buffer: usize,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
/// Simple allocator struct.
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@ -114,15 +67,9 @@ pub struct MetalDevice {
///
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
/// (strong_count = 1).
pub(crate) buffers: Arc<RwLock<BufferMap>>,
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
/// Heavily used by [`candle_metal_kernels`]
pub(crate) kernels: Arc<Kernels>,
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones.
pub(crate) use_mlx_mm: bool,
}
impl std::fmt::Debug for MetalDevice {
@ -140,32 +87,6 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
self.use_mlx_mm = use_mlx_mm
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
let mut buf = vec![];
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
let metal_code = String::from_utf8(buf)?;
let lib = self
.device
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
.map_err(MetalError::from)?;
let func = lib
.get_function(func_name, None)
.map_err(MetalError::from)?;
let pl = self
.device
.new_compute_pipeline_state_with_function(&func)
.map_err(MetalError::from)?;
Ok(pl)
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -174,31 +95,44 @@ impl MetalDevice {
&self.device
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
let (flushed, command_buffer) = commands.command_buffer()?;
if flushed {
self.drop_unused_buffers()?
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
command_buffer = self.command_queue.new_command_buffer().to_owned();
*command_buffer_lock = command_buffer.clone();
*index = 0;
self.drop_unused_buffers()?;
}
*index += 1;
Ok(command_buffer)
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut commands = self.commands.write().map_err(MetalError::from)?;
commands.wait_until_completed()
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Already committed");
}
_ => {}
}
command_buffer.commit();
command_buffer.wait_until_completed();
*command_buffer = self.command_queue.new_command_buffer().to_owned();
Ok(())
}
pub fn kernels(&self) -> &Kernels {
@ -246,7 +180,6 @@ impl MetalDevice {
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@ -277,6 +210,40 @@ impl MetalDevice {
Ok(buffer)
}
fn find_available_buffer(
&self,
size: NSUInteger,
option: MTLResourceOptions,
buffers: &RwLockWriteGuard<BufferMap>,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
.filter(|s| Arc::strong_count(*s) > 1)
.map(Arc::clone)
.collect();
*subbuffers = newbuffers;
}
Ok(())
}
/// The critical allocator algorithm
fn allocate_buffer(
&self,
@ -285,7 +252,7 @@ impl MetalDevice {
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = find_available_buffer(size, option, &buffers) {
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());
}
@ -324,23 +291,3 @@ impl MetalDevice {
fn buf_size(size: NSUInteger) -> NSUInteger {
size.saturating_sub(1).next_power_of_two() as NSUInteger
}
fn find_available_buffer(
size: NSUInteger,
option: MTLResourceOptions,
buffers: &BufferMap,
) -> Option<Arc<Buffer>> {
let mut best_buffer: Option<&Arc<Buffer>> = None;
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
for sub in subbuffers {
if Arc::strong_count(sub) == 1 {
best_buffer = Some(sub);
best_buffer_size = *buffer_size;
}
}
}
}
best_buffer.cloned()
}

View File

@ -412,42 +412,17 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
} else {
let kernel_name = match (self.dtype, dtype) {
(DType::BF16, DType::F16) => "cast_bf16_f16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(DType::BF16, DType::I64) => "cast_bf16_i64_strided",
(DType::BF16, DType::U32) => "cast_bf16_u32_strided",
(DType::BF16, DType::U8) => "cast_bf16_u8_strided",
(DType::F16, DType::BF16) => "cast_f16_bf16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::F16, DType::I64) => "cast_f16_i64_strided",
(DType::F16, DType::U32) => "cast_f16_u32_strided",
(DType::F16, DType::U8) => "cast_f16_u8_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F32, DType::I64) => "cast_f32_i64_strided",
(DType::F32, DType::U32) => "cast_f32_u32_strided",
(DType::F32, DType::U8) => "cast_f32_u8_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::I64, DType::BF16) => "cast_i64_bf16_strided",
(DType::I64, DType::F16) => "cast_i64_f16_strided",
(DType::I64, DType::U32) => "cast_i64_u32_strided",
(DType::I64, DType::U8) => "cast_i64_u8_strided",
(DType::U32, DType::BF16) => "cast_u32_bf16_strided",
(DType::U32, DType::F16) => "cast_u32_f16_strided",
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U8, DType::BF16) => "cast_u8_bf16_strided",
(DType::U8, DType::F16) => "cast_u8_f16_strided",
(DType::U32, DType::I64) => "cast_u32_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::U8, DType::I64) => "cast_u8_i64_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
(DType::F32, DType::BF16) => "cast_f32_bf16_strided",
(DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}
@ -1423,7 +1398,6 @@ impl BackendStorage for MetalStorage {
.map_err(MetalError::from)?;
Ok(acc)
}
fn matmul(
&self,
rhs: &Self,
@ -1432,78 +1406,32 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
DType::BF16 => "bgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
if self.dtype == DType::BF16 {
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
candle_metal_kernels::GemmDType::BF16,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else if self.device.use_mlx_mm {
let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"mlx matmul doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
} else {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(
MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(),
)
}
};
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
}
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
buffer,
self.device.clone(),
@ -1864,25 +1792,31 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
Ok(_) => false,
let buffers = Arc::new(RwLock::new(HashMap::new()));
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
id: DeviceId::new(),
device,
commands: Arc::new(RwLock::new(commands)),
buffers: Arc::new(RwLock::new(HashMap::new())),
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
kernels,
seed,
use_mlx_mm,
})
}
@ -1917,38 +1851,10 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {

View File

@ -6,15 +6,9 @@ use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
inner: CudaSlice<u8>,
len: usize,
}
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: PaddedCudaSlice,
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
@ -67,7 +61,7 @@ fn quantize_q8_1(
}
fn dequantize_f32(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -110,21 +104,21 @@ fn dequantize_f32(
};
if is_k {
let params = (&data.inner, &dst);
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_f16(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -167,21 +161,21 @@ fn dequantize_f16(
};
if is_k {
let params = (&data.inner, &dst);
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -190,7 +184,7 @@ fn dequantize_mul_mat_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -219,13 +213,13 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -235,7 +229,7 @@ fn mul_mat_vec_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -282,7 +276,7 @@ fn mul_mat_vec_via_q8_1(
};
let params = (
&data.inner,
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
@ -296,7 +290,7 @@ fn mul_mat_vec_via_q8_1(
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
@ -307,7 +301,7 @@ fn mul_mat_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
@ -321,7 +315,7 @@ fn mul_mat_via_q8_1(
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
@ -351,7 +345,7 @@ fn mul_mat_via_q8_1(
};
let params = (
/* vx */ &data.inner,
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
@ -367,14 +361,9 @@ fn mul_mat_via_q8_1(
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: size_in_bytes,
},
data,
device: device.clone(),
dtype,
})
@ -414,10 +403,7 @@ impl QCudaStorage {
}
// Run the dequantization on cpu.
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -458,21 +444,13 @@ impl QCudaStorage {
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
};
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len
self.data.len()
}
pub fn fwd(
@ -595,19 +573,11 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: data.len(),
},
data,
device: device.clone(),
dtype,
dtype: T::DTYPE,
}))
}
@ -707,28 +677,4 @@ mod test {
assert_eq!(vs[15], 13138824.0);
Ok(())
}
// The following test used to fail under compute-sanitizer until #2526.
#[test]
fn cuda_mm_q8_1_pad() -> Result<()> {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ x_rows,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ y_cols,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
Ok(())
}
}

View File

@ -304,7 +304,6 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
Minus(usize),
}
impl D {
@ -312,7 +311,6 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@ -329,7 +327,6 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@ -339,7 +336,6 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}

View File

@ -1,206 +0,0 @@
use crate::{Result, Shape, Tensor};
pub trait Dim: crate::shape::Dim + Copy {}
impl<T: crate::shape::Dim + Copy> Dim for T {}
/// A stream tensor is used in streaming module. It can either contain an actual tensor or be
/// empty.
#[derive(Clone)]
pub struct StreamTensor(Option<Tensor>);
impl std::fmt::Debug for StreamTensor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(t) => write!(f, "{:?}", t.shape()),
None => write!(f, "Empty"),
}
}
}
impl std::convert::From<Option<Tensor>> for StreamTensor {
fn from(value: Option<Tensor>) -> Self {
Self(value)
}
}
impl std::convert::From<Tensor> for StreamTensor {
fn from(value: Tensor) -> Self {
Self(Some(value))
}
}
impl std::convert::From<()> for StreamTensor {
fn from(_value: ()) -> Self {
Self(None)
}
}
impl StreamTensor {
pub fn empty() -> Self {
Self(None)
}
pub fn from_tensor(tensor: Tensor) -> Self {
Self(Some(tensor))
}
pub fn shape(&self) -> Option<&Shape> {
self.0.as_ref().map(|t| t.shape())
}
pub fn cat2<D: Dim>(&self, rhs: &Self, dim: D) -> Result<Self> {
let xs = match (&self.0, &rhs.0) {
(Some(lhs), Some(rhs)) => {
let xs = Tensor::cat(&[lhs, rhs], dim)?;
Some(xs)
}
(Some(xs), None) | (None, Some(xs)) => Some(xs.clone()),
(None, None) => None,
};
Ok(Self(xs))
}
pub fn seq_len<D: Dim>(&self, dim: D) -> Result<usize> {
match &self.0 {
None => Ok(0),
Some(v) => v.dim(dim),
}
}
pub fn reset(&mut self) {
self.0 = None
}
pub fn narrow<D: Dim>(&self, dim: D, offset: usize, len: usize) -> Result<StreamTensor> {
let t = match &self.0 {
None => None,
Some(t) => {
let seq_len = t.dim(dim)?;
if seq_len <= offset {
None
} else {
let t = t.narrow(dim, offset, usize::min(len, seq_len - offset))?;
Some(t)
}
}
};
Ok(Self(t))
}
/// Splits the Streaming Tensor on the time axis `dim` with the first `lhs_len` elements
/// returned in the first output and the remaining in the second output.
pub fn split<D: Dim>(&self, dim: D, lhs_len: usize) -> Result<(Self, Self)> {
match &self.0 {
None => Ok((Self::empty(), Self::empty())),
Some(t) => {
let seq_len = t.dim(dim)?;
let lhs_len = usize::min(seq_len, lhs_len);
if lhs_len == 0 {
Ok((Self::empty(), t.clone().into()))
} else {
let lhs = Self::from_tensor(t.narrow(dim, 0, lhs_len)?);
let rhs_len = seq_len - lhs_len;
let rhs = if rhs_len == 0 {
Self::empty()
} else {
Self::from_tensor(t.narrow(dim, lhs_len, rhs_len)?)
};
Ok((lhs, rhs))
}
}
}
}
pub fn as_option(&self) -> Option<&Tensor> {
self.0.as_ref()
}
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
match &self.0 {
None => Ok(Self::empty()),
Some(t) => Ok(Self::from_tensor(t.apply(m)?)),
}
}
}
/// Streaming modules take as input a stream tensor and return a stream tensor. They may perform
/// some internal buffering so that enough data has been received for the module to be able to
/// perform some operations.
pub trait StreamingModule {
// TODO: Should we also have a flush method?
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor>;
fn reset_state(&mut self);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
Add,
Mul,
Sub,
Div,
}
#[derive(Debug, Clone)]
pub struct StreamingBinOp {
prev_lhs: StreamTensor,
prev_rhs: StreamTensor,
pub op: BinOp,
pub dim: crate::D,
}
impl StreamingBinOp {
pub fn new(op: BinOp, dim: crate::D) -> Self {
Self {
prev_lhs: StreamTensor::empty(),
prev_rhs: StreamTensor::empty(),
op,
dim,
}
}
pub fn reset_state(&mut self) {
self.prev_lhs.reset();
self.prev_rhs.reset();
}
pub fn forward(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor> {
match self.op {
BinOp::Add => Tensor::add(lhs, rhs),
BinOp::Mul => Tensor::mul(lhs, rhs),
BinOp::Sub => Tensor::sub(lhs, rhs),
BinOp::Div => Tensor::div(lhs, rhs),
}
}
pub fn step(&mut self, lhs: &StreamTensor, rhs: &StreamTensor) -> Result<StreamTensor> {
let lhs = StreamTensor::cat2(&self.prev_lhs, lhs, self.dim)?;
let rhs = StreamTensor::cat2(&self.prev_rhs, rhs, self.dim)?;
let lhs_len = lhs.seq_len(self.dim)?;
let rhs_len = rhs.seq_len(self.dim)?;
let common_len = usize::min(lhs_len, rhs_len);
let (lhs, prev_lhs) = lhs.split(self.dim, common_len)?;
let (rhs, prev_rhs) = rhs.split(self.dim, common_len)?;
let ys = match (lhs.0, rhs.0) {
(Some(lhs), Some(rhs)) => {
let ys = self.forward(&lhs, &rhs)?;
StreamTensor::from_tensor(ys)
}
(None, None) => StreamTensor::empty(),
(lhs, rhs) => crate::bail!("INTERNAL ERROR inconsistent lhs and rhs {lhs:?} {rhs:?}"),
};
self.prev_lhs = prev_lhs;
self.prev_rhs = prev_rhs;
Ok(ys)
}
}
/// Simple wrapper that doesn't do any buffering.
pub struct Map<T: crate::Module>(T);
impl<T: crate::Module> StreamingModule for Map<T> {
fn reset_state(&mut self) {}
fn step(&mut self, xs: &StreamTensor) -> Result<StreamTensor> {
xs.apply(&self.0)
}
}

View File

@ -370,15 +370,6 @@ impl Tensor {
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [3.5, 3.5, 3.5, 3.5],
/// [3.5, 3.5, 3.5, 3.5],
/// ]);
/// # Ok::<(), candle_core::Error>(())
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
@ -388,13 +379,6 @@ impl Tensor {
}
/// Creates a new 1D tensor from an iterator.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device,
@ -406,26 +390,12 @@ impl Tensor {
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `1` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
Self::arange_step(start, end, D::one(), device)
}
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
@ -471,16 +441,6 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input vector. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
/// If the device is cpu, no data copy is made.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [1., 2., 3.],
/// [4., 5., 6.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
@ -491,17 +451,6 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input slice. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
///```rust
/// use candle_core::{Tensor, Device};
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [2., 3., 4.],
/// [5., 6., 7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
@ -783,30 +732,6 @@ impl Tensor {
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
/// ```
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[
/// [0f32, 1., 2.],
/// [3. , 4., 5.],
/// [6. , 7., 8.]
/// ], &Device::Cpu)?;
///
/// let b = a.narrow(0, 1, 2)?;
/// assert_eq!(b.shape().dims(), &[2, 3]);
/// assert_eq!(b.to_vec2::<f32>()?, &[
/// [3., 4., 5.],
/// [6., 7., 8.]
/// ]);
///
/// let c = a.narrow(1, 1, 1)?;
/// assert_eq!(c.shape().dims(), &[3, 1]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [1.],
/// [4.],
/// [7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
@ -1520,15 +1445,14 @@ impl Tensor {
/// # Arguments
///
/// * `self` - The input tensor.
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
/// but can have a different number of elements on the target dimension.
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let self_dims = self.dims();
let indexes_dims = indexes.dims();
let mismatch = if indexes_dims.len() != self_dims.len() {
@ -1536,7 +1460,7 @@ impl Tensor {
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
if i != dim && d1 < d2 {
if i != dim && d1 != d2 {
mismatch = true;
break;
}
@ -2026,11 +1950,7 @@ impl Tensor {
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!(
"not implemented yet, self.device: {:?}, device: {:?}",
self.device(),
device
)
bail!("not implemented yet")
}
};
let op = BackpropOp::new1(self, Op::ToDevice);

View File

@ -143,39 +143,3 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}
#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
let kernel = {
use ug::lang::op;
let layout = ug::Layout::from_shape(&[12]);
let ptr = op::Arg::ptr(ug::DType::F32);
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
let src = op::unary(op::UnaryOp::Exp, src)?;
let st = op::store(ptr.id(), layout, src)?;
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
let device = if candle_core::utils::cuda_is_available() {
Device::new_cuda(0)?
} else if candle_core::utils::metal_is_available() {
Device::new_metal(0)?
} else {
candle_core::bail!("metal/cuda is mandatory for this test")
};
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
to_vec1_round(&t, 2)?,
&[
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
59874.13
]
);
Ok(())
}

View File

@ -29,36 +29,6 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
],
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
]
],
);
assert_eq!(
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
[
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
],
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
]
],
);
Ok(())
}
@ -223,19 +193,6 @@ fn unary_op(device: &Device) -> Result<()> {
tensor.sign()?.to_vec1::<f32>()?,
[-1., -1., -1., 0., 0., 1., 1., 1., 1.]
);
let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?;
let y = tensor.elu(2.)?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
);
// This test failed on metal prior to the following PR:
// https://github.com/huggingface/candle/pull/2490
let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, -1.7293, 0.0000, 3.0000]
);
Ok(())
}
@ -1047,280 +1004,6 @@ fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
// Random data
// Dim: 0
let t = Tensor::new(
&[
[
[108_f32, -47., 16., -56., -83., -130., 210.],
[253., 95., 151., 228., -210., -123., -127.],
[-9., -217., 2., -78., 163., 245., -204.],
[-246., 79., -238., 88., -226., -184., 171.],
[8., -48., -153., 234., -34., 166., -153.],
[124., 0., -10., -61., -242., -15., -238.],
],
[
[12., -64., -199., 244., -240., 156., -128.],
[173., -57., 4., -198., 233., -110., 238.],
[95., 82., 0., 240., 53., -211., 209.],
[-122., 167., -212., 227., -144., 61., 118.],
[-63., -146., 200., 244., 168., -167., 116.],
[-125., -147., 110., -253., -178., -250., -18.],
],
[
[57., 86., -50., 56., 92., 205., -78.],
[-137., -156., -18., 248., -61., -239., 14.],
[-248., -30., -50., -70., -251., 250., -83.],
[-221., 67., 72., 59., -24., -154., 232.],
[-144., -23., -74., 5., 93., 171., 205.],
[46., -77., -38., -226., 246., 161., -17.],
],
[
[-153., -231., -236., 161., 126., 2., -22.],
[-229., -41., 209., 164., 234., 160., 57.],
[223., 254., -186., -162., -46., -160., -102.],
[65., 30., 213., -253., 59., 224., -154.],
[-82., -203., -177., 17., 31., -256., -246.],
[176., -135., -65., 54., -56., 210., 76.],
],
[
[-10., -245., 168., 124., -14., -33., -178.],
[25., -43., -39., 132., -89., 169., 179.],
[187., -215., 32., -133., 87., -7., -168.],
[-224., -215., -5., -230., -58., -162., 128.],
[158., -137., -122., -100., -202., -83., 136.],
[30., -185., -144., 250., 209., -40., 127.],
],
[
[-196., 108., -245., 122., 146., -228., 62.],
[-1., -66., 160., 137., 13., -172., -21.],
[244., 199., -164., 28., 119., -175., 198.],
[-62., 253., -162., 195., -95., -230., -211.],
[123., -72., -26., -107., -139., 64., 245.],
[11., -126., -182., 108., -12., 184., -127.],
],
[
[-159., 126., 176., 161., 73., -111., -138.],
[-187., 214., -217., -33., -223., -201., -212.],
[-61., -120., -166., -172., -95., 53., 196.],
[-33., 86., 134., -152., 154., -53., 74.],
[186., -28., -154., -174., 141., -109., 217.],
[82., 35., 252., 145., 181., 74., -87.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[
[6_u32, 6, 4, 3, 4, 4, 6],
[3, 3, 2, 4, 4, 4, 6],
[3, 3, 0, 2, 4, 6, 4],
[2, 5, 1, 2, 6, 6, 1],
[2, 1, 6, 5, 3, 2, 3],
[6, 1, 0, 1, 0, 2, 6],
],
[
[4, 6, 4, 3, 3, 3, 2],
[4, 3, 2, 4, 4, 4, 6],
[2, 3, 0, 2, 4, 6, 4],
[6, 5, 1, 2, 6, 6, 1],
[4, 1, 6, 5, 3, 2, 3],
[1, 1, 0, 1, 0, 2, 6],
],
[
[3, 6, 4, 3, 3, 3, 2],
[2, 3, 2, 4, 4, 4, 6],
[4, 3, 0, 2, 4, 6, 4],
[0, 5, 1, 2, 6, 6, 1],
[6, 1, 6, 5, 3, 2, 3],
[4, 1, 0, 1, 0, 2, 6],
],
[
[0, 6, 4, 3, 3, 3, 2],
[5, 3, 2, 4, 4, 4, 6],
[0, 3, 0, 2, 4, 6, 4],
[3, 5, 1, 2, 6, 6, 1],
[0, 1, 6, 5, 3, 2, 3],
[3, 1, 0, 1, 0, 2, 6],
],
],
device,
)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[
[-159_f32, 126., 168., 161., -14., -33., -138.],
[-229., -41., -18., 132., -89., 169., -212.],
[223., 254., 2., -70., 87., 53., -168.],
[-221., 253., -212., 59., 154., -53., 118.],
[-144., -146., -154., -107., 31., 171., -246.],
[82., -147., -10., -253., -242., 161., -87.]
],
[
[-10., 126., 168., 161., 126., 2., -78.],
[25., -41., -18., 132., -89., 169., -212.],
[-248., 254., 2., -70., 87., 53., -168.],
[-33., 253., -212., 59., 154., -53., 118.],
[158., -146., -154., -107., 31., 171., -246.],
[-125., -147., -10., -253., -242., 161., -87.]
],
[
[-153., 126., 168., 161., 126., 2., -78.],
[-137., -41., -18., 132., -89., 169., -212.],
[187., 254., 2., -70., 87., 53., -168.],
[-246., 253., -212., 59., 154., -53., 118.],
[186., -146., -154., -107., 31., 171., -246.],
[30., -147., -10., -253., -242., 161., -87.]
],
[
[108., 126., 168., 161., 126., 2., -78.],
[-1., -41., -18., 132., -89., 169., -212.],
[-9., 254., 2., -70., 87., 53., -168.],
[65., 253., -212., 59., 154., -53., 118.],
[8., -146., -154., -107., 31., 171., -246.],
[176., -147., -10., -253., -242., 161., -87.]
]
]
);
// Dim: 1
let t = Tensor::new(
&[
[
[-117_f32, -175., 69., -163.],
[200., 242., -21., -67.],
[179., 150., -126., -75.],
[-118., 38., -138., -13.],
[-221., 136., -185., 180.],
[58., 182., -204., -149.],
],
[
[3., -148., -58., -154.],
[-43., 45., -108., 4.],
[-69., -249., -71., -21.],
[80., 110., -152., -235.],
[-88., 7., 92., -250.],
[-186., 207., -242., 98.],
],
[
[238., 19., 64., -242.],
[-150., -97., 218., 58.],
[111., -233., 204., -212.],
[-242., -232., 83., 42.],
[153., 62., -251., 219.],
[-117., 36., -119., 10.],
],
[
[215., 159., -169., -27.],
[-83., 101., -88., 169.],
[-205., 93., 225., -64.],
[-162., 240., 214., 23.],
[-112., 6., 21., 245.],
[-38., 113., 93., 215.],
],
[
[91., -188., -148., 101.],
[74., 203., -35., 55.],
[-116., -130., -153., -96.],
[58., 22., -45., -194.],
[-221., -134., 73., 159.],
[-203., -254., 31., 235.],
],
[
[105., -53., 61., 186.],
[-195., 234., 75., -1.],
[51., 139., 160., -108.],
[-173., -167., 161., 19.],
[83., -246., 156., -222.],
[109., 39., -149., 137.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[[4_u32, 4, 4, 2]],
[[0, 4, 4, 3]],
[[1, 5, 3, 4]],
[[0, 3, 3, 2]],
[[1, 1, 5, 2]],
[[1, 4, 5, 4]],
],
device,
)?;
let hs = t.gather(&ids, 1)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-221., 136., -185., -75.]],
[[3., 7., 92., -235.]],
[[-150., 36., 83., 219.]],
[[215., 240., 214., -64.]],
[[74., 203., 31., -96.]],
[[-195., -246., -149., -222.]]
]
);
// Dim: 2
let t = Tensor::new(
&[
[[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]],
[[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]],
],
device,
)?;
let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[202.], [-126.], [-65.], [80.]],
[[37.], [89.], [117.], [220.]]
]
);
let t = Tensor::new(
&[
[[-21_f32, -197.], [194., 122.]],
[[255., -106.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[-130., 238.], [-217., -92.]],
],
device,
)?;
let ids = Tensor::new(
&[
[[0_u32, 1], [1, 0]],
[[1, 0], [0, 1]],
[[0, 1], [0, 1]],
[[1, 0], [1, 0]],
],
device,
)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-21., -197.], [122., 194.]],
[[-106., 255.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[238., -130.], [-92., -217.]]
]
);
Ok(())
}

View File

@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
@ -36,7 +36,6 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
[dev-dependencies]
anyhow = { workspace = true }
@ -66,9 +65,8 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
@ -103,10 +101,6 @@ required-features = ["candle-datasets"]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "mimi"
required-features = ["mimi"]
[[example]]
name = "encodec"
required-features = ["encodec"]
@ -114,11 +108,3 @@ required-features = ["encodec"]
[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

View File

@ -1,20 +0,0 @@
# candle-based
Experimental, not instruction-tuned small LLM from the Hazy Research group, combining local and linear attention layers.
[Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
[Simple linear attention language models balance the recall-throughput tradeoff](https://arxiv.org/abs/2402.18668)
## Running an example
```bash
$ cargo run --example based --release -- --prompt "Flying monkeys are" --which 1b-50b --sample-len 100
Flying monkeys are a common sight in the wild, but they are also a threat to humans.
The new study, published today (July 31) in the journal Science Advances, shows that the monkeys are using their brains to solve the problem of how to get around the problem.
"We found that the monkeys were using a strategy called 'cognitive mapping' - they would use their brains to map out the route ahead," says lead author Dr. David J. Smith from the University of California
```

View File

@ -1,275 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::based::Model;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
#[value(name = "360m")]
W360m,
#[value(name = "1b")]
W1b,
#[value(name = "1b-50b")]
W1b50b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "refs/pr/1")]
revision: String,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "360m")]
which: Which,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::W360m => "hazyresearch/based-360m".to_string(),
Which::W1b => "hazyresearch/based-1b".to_string(),
Which::W1b50b => "hazyresearch/based-1b-50b".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let config_file = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let repo = api.model("openai-community/gpt2".to_string());
let tokenizer_file = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let mut vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.which == Which::W1b50b {
vb = vb.pp("model");
};
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,224 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, Device, Tensor};
use candle_nn as nn;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
tracing_subscriber::fmt::init();
let device = candle_examples::device(args.cpu)?;
let var = load_weights(args.model, &device)?;
let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?;
tracing::info!("Transformer loaded. ");
let (pixel_values, vec_imgs) = load_images(args.images, &device)?;
tracing::info!("Images loaded. ");
let tokenizer = load_tokenizer()?;
let (input_ids, type_ids, attention_mask, text_sequences) =
tokenize_sequences(args.sequences, &tokenizer, &device)?;
tracing::info!("Computing ... ");
let (_logits_per_text, logits_per_image) = clip_model.forward(
&pixel_values,
&input_ids,
Some(&type_ids),
Some(&attention_mask),
)?;
let softmax_image = nn::ops::softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
tracing::info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]);
}
}
Ok(())
}
pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder> {
let model_file = match model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? })
}
pub fn load_tokenizer() -> anyhow::Result<Tokenizer> {
let tokenizer_file = {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("tokenizer.json")?
};
Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec<String>)> {
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"自行车比赛".to_string(),
"两只猫咪".to_string(),
"拿着蜡烛的机器人".to_string(),
],
};
let mut input_ids = vec![];
let mut type_ids = vec![];
let mut attention_mask = vec![];
let mut max_len = 0;
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?;
input_ids.push(encoding.get_ids().to_vec());
type_ids.push(encoding.get_type_ids().to_vec());
attention_mask.push(encoding.get_attention_mask().to_vec());
if encoding.get_ids().len() > max_len {
max_len = encoding.get_ids().len();
}
}
let pad_id = *tokenizer
.get_vocab(true)
.get("[PAD]")
.ok_or(anyhow::Error::msg("No pad token"))?;
let input_ids: Vec<Vec<u32>> = input_ids
.iter_mut()
.map(|item| {
item.extend(vec![pad_id; max_len - item.len()]);
item.to_vec()
})
.collect();
let type_ids: Vec<Vec<u32>> = type_ids
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let attention_mask: Vec<Vec<u32>> = attention_mask
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let input_ids = Tensor::new(input_ids, device)?;
let type_ids = Tensor::new(type_ids, device)?;
let attention_mask = Tensor::new(attention_mask, device)?;
Ok((input_ids, type_ids, attention_mask, vec_seq))
}
pub fn load_images(
images: Option<Vec<String>>,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let vec_imgs = match images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let mut images = vec![];
for path in vec_imgs.iter() {
let tensor = load_image(path, 224, device)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?.to_device(device)?;
Ok((images, vec_imgs))
}
fn load_image<T: AsRef<std::path::Path>>(
path: T,
image_size: usize,
device: &Device,
) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8().into_raw();
let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
let std =
Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
let img = (img.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)?;
Ok(img)
}

View File

@ -12,6 +12,7 @@ use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
@ -39,12 +40,15 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
@ -53,16 +57,24 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
@ -77,9 +89,13 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -87,29 +103,43 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
info!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
@ -126,6 +156,7 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
@ -138,6 +169,7 @@ pub fn tokenize_sequences(
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
@ -146,12 +178,16 @@ pub fn tokenize_sequences(
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
@ -159,6 +195,8 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,18 +0,0 @@
# Colpali
[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
```
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```
```
Prompt: what is position encoding?
top 3 page numbers that contain similarity to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```

View File

@ -1,268 +0,0 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;
struct PageRetriever {
model: Model,
config: paligemma::Config,
pdf: PDF,
device: Device,
tokenizer: Tokenizer,
range: pdf2image::Pages,
batch_size: usize,
top_k: usize,
}
impl PageRetriever {
fn new(
model: Model,
config: paligemma::Config,
pdf: PDF,
tokenizer: Tokenizer,
device: &Device,
range: Option<pdf2image::Pages>,
batch_size: usize,
top_k: usize,
) -> Self {
let page_count = pdf.page_count();
Self {
model,
config,
pdf,
device: device.clone(),
tokenizer,
range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
batch_size,
top_k,
}
}
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
let pages = self
.pdf
.render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
Ok(pages)
}
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.device)
})
.collect::<candle::Result<Vec<_>>>()?;
let input = Tensor::stack(&token_ids, 0)?;
Ok(input)
}
fn images_to_tensor(
&self,
pages: &[DynamicImage],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for page in pages.iter() {
let img = page.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
images.push(img);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
let dtype = if self.device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dummy_prompt: &str = "Describe the image";
let input = self.tokenize_batch(vec![prompt])?;
let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
let pages = self.get_images_from_pdf()?;
let mut all_scores = Vec::new();
for batch in pages.chunks(self.batch_size) {
let page_images = self
.images_to_tensor(batch, self.config.vision_config.image_size)?
.to_device(&self.device)?
.to_dtype(dtype)?;
let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
let text_embeddings = self.model.forward_text(&input)?;
let scores = text_embeddings
.unsqueeze(1)?
.broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
.max(3)?
.sum(2)?;
let batch_scores: Vec<f32> = scores
.to_dtype(DType::F32)?
.to_vec2()?
.into_iter()
.flatten()
.collect();
all_scores.extend(batch_scores);
}
let mut indices: Vec<usize> = (0..all_scores.len()).collect();
indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
let top_k_indices = indices[0..self.top_k].to_vec();
Ok(top_k_indices)
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// number of top pages to show.
#[arg(long, default_value_t = 3)]
top_k: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
pdf: String,
#[arg(long)]
start: Option<u32>,
#[arg(long)]
end: Option<u32>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "vidore/colpali-v1.2-merged".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.repo(Repo::with_revision(
"vidore/colpali".to_string(),
RepoType::Model,
"main".to_string(),
))
.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let start = std::time::Instant::now();
let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(false)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = colpali::Model::new(&config, vb)?;
let pdf = PDF::from_file(args.pdf)?;
// check if start and end given in arg
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
pdf2image::Pages::Range(start..=end)
} else {
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
};
let mut retriever =
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
let top_k_indices = retriever.retrieve(&args.prompt)?;
println!("Prompt: {}", args.prompt);
println!(
"top {} page numbers that contain similarity to the prompt",
retriever.top_k
);
println!("-----------------------------------");
for index in top_k_indices {
println!("Page: {:?}", index + 1);
}
println!("-----------------------------------");
Ok(())
}

View File

@ -7,7 +7,7 @@ quantization.
## Running one example
```bash
cargo run --example encodec --features encodec --release -- code-to-audio \
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```

View File

@ -1,3 +1,4 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

View File

@ -1,20 +0,0 @@
# candle-fastvit
[FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization](https://arxiv.org/abs/2303.14189).
This candle implementation uses a pre-trained FastViT network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example fastvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which sa12
loaded image Tensor[dims 3, 256, 256; f32]
model built
mountain bike, all-terrain bike, off-roader: 52.67%
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
unicycle, monocycle : 3.46%
maillot : 1.32%
crash helmet : 1.28%
```

View File

@ -1,102 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::fastvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
T8,
T12,
S12,
SA12,
SA24,
SA36,
MA36,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::T8 => "t8",
Self::T12 => "t12",
Self::S12 => "s12",
Self::SA12 => "sa12",
Self::SA24 => "sa24",
Self::SA36 => "sa36",
Self::MA36 => "ma36",
};
format!("timm/fastvit_{}.apple_in1k", name)
}
fn config(&self) -> fastvit::Config {
match self {
Self::T8 => fastvit::Config::t8(),
Self::T12 => fastvit::Config::t12(),
Self::S12 => fastvit::Config::s12(),
Self::SA12 => fastvit::Config::sa12(),
Self::SA24 => fastvit::Config::sa24(),
Self::SA36 => fastvit::Config::sa36(),
Self::MA36 => fastvit::Config::ma36(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S12)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image(args.image, 256)?.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = fastvit::fastvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -13,7 +13,7 @@ descriptions,
```bash
cargo run --features cuda --example flux -r -- \
--height 1024 --width 1024 \
--height 1024 --width 1024
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
```

View File

@ -23,10 +23,6 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use the quantized model.
#[arg(long)]
quantized: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -44,14 +40,6 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
/// Use the slower kernels.
#[arg(long)]
use_dmmv: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -72,8 +60,6 @@ fn run(args: Args) -> Result<()> {
tracing,
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -95,9 +81,6 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
if let Some(seed) = args.seed {
device.set_seed(seed)?;
}
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
@ -163,71 +146,38 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
Model::Dev => bf_repo.get("flux1-dev.sft")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = if quantized {
flux::sampling::State::new(
&t5_emb.to_dtype(candle::DType::F32)?,
&clip_emb.to_dtype(candle::DType::F32)?,
&img.to_dtype(candle::DType::F32)?,
)?
} else {
flux::sampling::State::new(&t5_emb, &clip_emb, &img)?
};
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;
println!("{state:?}");
println!("{timesteps:?}");
if quantized {
let model_file = match model {
Model::Schnell => api
.repo(hf_hub::Repo::model("lmz/candle-flux".to_string()))
.get("flux1-schnell.gguf")?,
Model::Dev => todo!(),
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
model_file, &device,
)?;
let model = flux::quantized_model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
.to_dtype(dtype)?
} else {
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.safetensors")?,
Model::Dev => bf_repo.get("flux1-dev.safetensors")?,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)?
};
let model = flux::model::Flux::new(&cfg, vb)?;
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
}
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
};
flux::sampling::unpack(&img, height, width)?
}
@ -239,7 +189,7 @@ fn run(args: Args) -> Result<()> {
println!("latent img\n{img}");
let img = {
let model_file = bf_repo.get("ae.safetensors")?;
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),
@ -256,7 +206,5 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args)
}

View File

@ -1,6 +0,0 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
BASE_MODEL = "google/t5-v1_1-xxl"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# The tokenizer will be saved in /tmp/tokenizer/tokenizer.json
tokenizer.save_pretrained("/tmp/tokenizer/")

View File

@ -1,27 +1,27 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant for the first
version, and a 2b and a 9b variant for v2.
models published by Google Deepmind with a 2b and a 7b variant.
## Running the example
```bash
$ cargo run --example gemma --features cuda -r -- \
--prompt "Here is a proof that square root of 2 is not rational: "
Here is a proof that square root of 2 is not rational:
Let us assume it to be rational. Then, we can write √2 = p/q where q ≠ 0 and p and q are integers with no common factors other than 1. Squaring both sides gives us (p/q)^2 = 2 or p^2/q^2 = 2. This implies that p^2 is divisible by 2, which means that p must be even. Let us write p = 2m where m is an integer. Substituting this in the above equation we get:
(p^2)/q^2 = 2 or (4m^2)/q^2 = 2 or q^2/2m^2 = 1 which implies that q^2 must be divisible by 2, and hence q is even. This contradicts our assumption that p and q have no common factors other than 1. Hence we conclude that √2 cannot be rational.
```
## Access restrictions
In order to use the v1 examples, you have to accept the license on the
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -39,46 +38,6 @@ enum Which {
CodeInstruct2B,
#[value(name = "code-7b-it")]
CodeInstruct7B,
#[value(name = "2-2b")]
BaseV2_2B,
#[value(name = "2-2b-it")]
InstructV2_2B,
#[value(name = "2-9b")]
BaseV2_9B,
#[value(name = "2-9b-it")]
InstructV2_9B,
}
impl Which {
fn is_v1(&self) -> bool {
match self {
Self::Base2B
| Self::Base7B
| Self::Instruct2B
| Self::Instruct7B
| Self::InstructV1_1_2B
| Self::InstructV1_1_7B
| Self::CodeBase2B
| Self::CodeBase7B
| Self::CodeInstruct2B
| Self::CodeInstruct7B => true,
Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
}
}
}
enum Model {
V1(Model1),
V2(Model2),
}
impl Model {
fn forward(&mut self, input_ids: &Tensor, pos: usize) -> candle::Result<Tensor> {
match self {
Self::V1(m) => m.forward(input_ids, pos),
Self::V2(m) => m.forward(input_ids, pos),
}
}
}
struct TextGeneration {
@ -232,7 +191,7 @@ struct Args {
repeat_last_n: usize,
/// The model to use.
#[arg(long, default_value = "2-2b")]
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
@ -280,10 +239,6 @@ fn main() -> Result<()> {
Which::CodeBase7B => "google/codegemma-7b".to_string(),
Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
Which::BaseV2_2B => "google/gemma-2-2b".to_string(),
Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -308,6 +263,7 @@ fn main() -> Result<()> {
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
@ -317,15 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = if args.which.is_v1() {
let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model1::new(args.use_flash_attn, &config, vb)?;
Model::V1(model)
} else {
let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
};
let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -1,77 +0,0 @@
* GLM4
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.
- [[https://github.com/THUDM/GLM4][Github]]
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]
** Running with ~cuda~
#+begin_src shell
cargo run --example glm4 --release --features cuda
#+end_src
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
#+end_src
** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换DFT的方法它广泛应用于信号处理、图像处理和数据分析等领域。
具体来说FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中我们通常需要知道信号的频率成分这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。
以下是使用 Python 中的 numpy 进行 FFT 的简单示例:
```python
import numpy as np
# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)
# 对该信号做FFT变换并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))
```
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
#+end_src
This example will read prompt from stdin
* Citation
#+begin_src
@misc{glm2024chatglm,
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
year={2024},
eprint={2406.12793},
archivePrefix={arXiv},
primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
}
#+end_src
#+begin_src
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
#+end_src

View File

@ -1,255 +0,0 @@
use candle_transformers::models::glm4::*;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
use std::io::Write;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");
let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();
let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,
#[arg(long)]
cpu: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8192)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.2)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.6),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now();
let config = Config::glm4();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-granite LLMs from IBM Research
[Granite](https://www.ibm.com/granite) is a family of Large Language Models built for business, to help drive trust and scalability in AI-driven applications.
## Running the example
```bash
$ cargo run --example granite --features metal -r -- --model-type "granite7b-instruct" \
--prompt "Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind"
Explain how quantum computing differs from classical computing, focusing on key concepts like qubits, superposition, and entanglement. Describe two potential breakthroughs in the fields of drug discovery and cryptography. Offer a convincing argument for why businesses and governments should invest in quantum computing research now, emphasizing its future benefits and the risks of falling behind competitors.
In recent years, there has been significant interest in quantum computing due to its potential to revolutionize various fields, including drug discovery, cryptography, and optimization problems. Quantum computers, which leverage the principles of quantum mechanics, differ fundamentally from classical computers. Here are some of the key differences:
```
## Supported Models
There are two different modalities for the Granite family models: Language and Code.
### Granite for language
1. [Granite 7b Instruct](https://huggingface.co/ibm-granite/granite-7b-instruct)

View File

@ -1,251 +0,0 @@
// An implementation of different Granite models https://www.ibm.com/granite
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use candle_transformers::models::granite as model;
use model::{Granite, GraniteConfig};
use std::time::Instant;
const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "How Fault Tolerant Quantum Computers will help humanity?";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum GraniteModel {
Granite7bInstruct,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(short = 'n', long, default_value_t = 10000)]
sample_len: usize,
/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,
/// The initial prompt.
#[arg(long)]
prompt: Option<String>,
/// Use different dtype than f16
#[arg(long)]
dtype: Option<String>,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long, default_value = "granite7b-instruct")]
model_type: GraniteModel,
#[arg(long)]
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (granite, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.model_type {
GraniteModel::Granite7bInstruct => "ibm-granite/granite-7b-instruct".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let config: GraniteConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let filenames = match args.model_type {
GraniteModel::Granite7bInstruct => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(
Granite::load(vb, &config)?,
tokenizer_filename,
cache,
config,
)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::GraniteEosToks::Single)
});
let default_prompt = match args.model_type {
GraniteModel::Granite7bInstruct => DEFAULT_PROMPT,
};
let prompt = args.prompt.as_ref().map_or(default_prompt, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("Starting the inference loop:");
print!("{prompt}");
let mut logits_processor = {
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
let use_cache_kv = cache.use_kv_cache;
(0..args.sample_len)
.inspect(|index| {
if *index == 1 {
start_gen = Instant::now();
}
})
.try_for_each(|index| -> Result<()> {
let (context_size, context_index) = if use_cache_kv && index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = granite
.forward(&input, context_index, &mut cache)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
token_generated += 1;
tokens.push(next_token);
if let Some(model::GraniteEosToks::Single(eos_tok_id)) = eos_token_id {
if next_token == eos_tok_id {
return Err(E::msg("EOS token found"));
}
} else if let Some(model::GraniteEosToks::Multiple(ref eos_ids)) = eos_token_id {
if eos_ids.contains(&next_token) {
return Err(E::msg("EOS token found"));
}
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
Ok(())
})
.unwrap_or(());
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
token_generated,
(token_generated - 1) as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -35,26 +35,10 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
#[value(name = "SmoLM2-1B")]
SmolLM2_1B,
#[value(name = "SmoLM2-1B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "SmoLM2-360M")]
SmolLM2_360M,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-135M")]
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
}
#[derive(Parser, Debug)]
@ -146,28 +130,15 @@ fn main() -> Result<()> {
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
@ -185,22 +156,10 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
| Which::SmolLM2_135MInstruct
| Which::SmolLM2_1B
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

View File

@ -14,7 +14,6 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::llama::LlamaEosToks;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
@ -220,16 +219,9 @@ fn main() -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
match config.eos_token_id {
Some(LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
_ => (),
if Some(next_token) == config.eos_token_id {
break;
}
if rank == 0 {
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");

View File

@ -43,14 +43,6 @@ def import_protobuf(error_message=""):
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
if add_prefix_space:
prepend_scheme = "always"
if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
prepend_scheme = "first"
else:
prepend_scheme = "never"
return prepend_scheme
class SentencePieceExtractor:
"""
@ -527,15 +519,13 @@ class SpmConverter(Converter):
)
def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
def post_processor(self):
return None
def decoder(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
@ -646,8 +636,7 @@ class DebertaV2Converter(SpmConverter):
list_pretokenizers = []
if self.original_tokenizer.split_by_punct:
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
return pre_tokenizers.Sequence(list_pretokenizers)
def normalizer(self, proto):
@ -940,11 +929,10 @@ class PegasusConverter(SpmConverter):
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
def pre_tokenizer(self, replacement, add_prefix_space):
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Sequence(
[
pre_tokenizers.WhitespaceSplit(),
pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
]
)

View File

@ -1,20 +0,0 @@
# candle-mimi
[Mimi](https://huggingface.co/kyutai/mimi) is a state of the art audio
compression model using an encoder/decoder architecture with residual vector
quantization. The candle implementation supports streaming meaning that it's
possible to encode or decode a stream of audio tokens on the flight to provide
low latency interaction with an audio model.
## Running one example
Generating some audio tokens from an audio files.
```bash
wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
```
And decoding the audio tokens back into a sound file.
```bash
cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
```

View File

@ -1,274 +0,0 @@
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};
pub const SAMPLE_RATE: usize = 24_000;
pub(crate) struct AudioOutputData_ {
resampled_data: std::collections::VecDeque<f32>,
resampler: rubato::FastFixedIn<f32>,
output_buffer: Vec<f32>,
input_buffer: Vec<f32>,
input_len: usize,
}
impl AudioOutputData_ {
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
use rubato::Resampler;
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
let resampler = rubato::FastFixedIn::new(
resample_ratio,
f64::max(resample_ratio, 1.0),
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
Ok(Self {
resampled_data,
resampler,
input_buffer,
output_buffer,
input_len: 0,
})
}
pub fn reset(&mut self) {
use rubato::Resampler;
self.output_buffer.fill(0.);
self.input_buffer.fill(0.);
self.resampler.reset();
self.resampled_data.clear();
}
pub(crate) fn take_all(&mut self) -> Vec<f32> {
let mut data = Vec::with_capacity(self.resampled_data.len());
while let Some(elem) = self.resampled_data.pop_back() {
data.push(elem);
}
data
}
pub(crate) fn is_empty(&self) -> bool {
self.resampled_data.is_empty()
}
// Assumes that the input buffer is large enough.
fn push_input_buffer(&mut self, samples: &[f32]) {
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
self.input_len += samples.len()
}
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
use rubato::Resampler;
let mut pos_in = 0;
loop {
let rem = self.input_buffer.len() - self.input_len;
let pos_end = usize::min(pos_in + rem, samples.len());
self.push_input_buffer(&samples[pos_in..pos_end]);
pos_in = pos_end;
if self.input_len < self.input_buffer.len() {
break;
}
let (_, out_len) = self.resampler.process_into_buffer(
&[&self.input_buffer],
&mut [&mut self.output_buffer],
None,
)?;
for &elem in self.output_buffer[..out_len].iter() {
self.resampled_data.push_front(elem)
}
self.input_len = 0;
}
Ok(())
}
}
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio output stream!");
let host = cpal::default_host();
let device = host
.default_output_device()
.context("no output device available")?;
let mut supported_configs_range = device.supported_output_configs()?;
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
// On macOS, it's commonly the case that there are only stereo outputs.
None => device
.supported_output_configs()?
.next()
.context("no audio output available")?,
Some(config_range) => config_range,
};
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
let channels = config.channels as usize;
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
SAMPLE_RATE,
config.sample_rate.0 as usize,
)?));
let ad = audio_data.clone();
let stream = device.build_output_stream(
&config,
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
data.fill(0.);
let mut ad = ad.lock().unwrap();
let mut last_elem = 0f32;
for (idx, elem) in data.iter_mut().enumerate() {
if idx % channels == 0 {
match ad.resampled_data.pop_back() {
None => break,
Some(v) => {
last_elem = v;
*elem = v
}
}
} else {
*elem = last_elem
}
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
println!("Setup audio input stream!");
let host = cpal::default_host();
let device = host
.default_input_device()
.context("no input device available")?;
let mut supported_configs_range = device.supported_input_configs()?;
let config_range = supported_configs_range
.find(|c| c.channels() == 1)
.context("no audio input available")?;
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
config_range.min_sample_rate(),
config_range.max_sample_rate(),
);
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
println!(
"cpal device: {} {} {config:?}",
device.name().unwrap_or_else(|_| "unk".to_string()),
config.sample_rate.0
);
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
config.sample_rate.0 as usize,
SAMPLE_RATE,
)?));
let ad = audio_data.clone();
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mut ad = ad.lock().unwrap();
if let Err(err) = ad.push_samples(data) {
eprintln!("error processing audio input {err:?}")
}
},
move |err| eprintln!("cpal error: {err}"),
None, // None=blocking, Some(Duration)=timeout
)?;
stream.play()?;
Ok((stream, audio_data))
}
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
use rubato::Resampler;
let mut pcm_out =
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
let mut output_buffer = resampler.output_buffer_allocate(true);
let mut pos_in = 0;
while pos_in + resampler.input_frames_next() < pcm_in.len() {
let (in_len, out_len) =
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
pos_in += in_len;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
if pos_in < pcm_in.len() {
let (_in_len, out_len) = resampler.process_partial_into_buffer(
Some(&[&pcm_in[pos_in..]]),
&mut output_buffer,
None,
)?;
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
}
Ok(pcm_out)
}

View File

@ -1,165 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mimi::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
mod audio_io;
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some mimi tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some mimi tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
/// Whether to use streaming or not, when streaming slices of data of the given size are passed
/// to the encoder/decoder one at a time.
#[arg(long)]
streaming: Option<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::v0_1(None);
let mut model = Model::new(config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
codes.get("codes").expect("no codes in input file").clone()
}
Action::AudioToCode | Action::AudioToAudio => {
let pcm = if args.in_file == "-" {
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
let (stream, input_audio) = audio_io::setup_input_stream()?;
let mut pcms = vec![];
let stdin = std::thread::spawn(|| {
let mut s = String::new();
std::io::stdin().read_line(&mut s)
});
while !stdin.is_finished() {
let input = input_audio.lock().unwrap().take_all();
if input.is_empty() {
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
pcms.push(input)
}
drop(stream);
pcms.concat()
} else {
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
audio_io::resample(&pcm, sample_rate as usize, 24_000)?
} else {
pcm
}
};
match args.streaming {
Some(chunk_size) => {
let mut code_chunks = vec![];
for pcm in pcm.chunks(chunk_size) {
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
let code_chunk = model.encode(&pcm)?;
code_chunks.push(code_chunk)
}
Tensor::cat(&code_chunks, candle::D::Minus1)?
}
None => {
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
}
}
};
println!("codes shape: {:?}", codes.shape());
model.reset_state();
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = match args.streaming {
Some(chunk_size) => {
let seq_len = codes.dim(candle::D::Minus1)?;
let mut pcm_chunks = vec![];
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
None => model.decode(&codes)?,
};
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
if args.out_file == "-" {
let (stream, ad) = audio_io::setup_output_stream()?;
{
let mut ad = ad.lock().unwrap();
ad.push_samples(&pcm)?;
}
loop {
let ad = ad.lock().unwrap();
if ad.is_empty() {
break;
}
// That's very weird, calling thread::sleep here triggers the stream to stop
// playing (the callback doesn't seem to be called anymore).
// std::thread::sleep(std::time::Duration::from_millis(100));
}
drop(stream)
} else {
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
}
Ok(())
}

View File

@ -1,28 +0,0 @@
# candle-mobileclip
MobileCLIP is family of efficient CLIP-like models using FastViT-based image encoders.
See [MobileCLIP: Fast Image-Text Models through Multi-Modal Reinforced Training](https://arxiv.org/abs/2311.17049)
## Running on an example on cpu
```
$ cargo run --example mobileclip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "a cycling race","a photo of two cats","a robot holding a candle"
softmax_image_vec: [2.4819004e-5, 3.81081e-6, 0.9999714, 0.9999738, 2.382714e-5, 2.3317718e-6]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0025% Text: a cycling race
Probability: 0.0004% Text: a photo of two cats
Probability: 99.9971% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 99.9974% Text: a cycling race
Probability: 0.0024% Text: a photo of two cats
Probability: 0.0002% Text: a robot holding a candle
```

View File

@ -1,170 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::mobileclip;
use tokenizers::Tokenizer;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S1,
S2,
}
impl Which {
fn model_name(&self) -> String {
let name = match self {
Self::S1 => "S1",
Self::S2 => "S2",
};
format!("apple/MobileCLIP-{}-OpenCLIP", name)
}
fn config(&self) -> mobileclip::MobileClipConfig {
match self {
Self::S1 => mobileclip::MobileClipConfig::s1(),
Self::S2 => mobileclip::MobileClipConfig::s2(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
#[arg(value_enum, long, default_value_t=Which::S1)]
which: Which,
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
image_size,
&[0.0, 0.0, 0.0],
&[1.0, 1.0, 1.0],
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? }
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {}", p, vec_seq[i]);
}
}
Ok(())
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
// let pad_id = *tokenizer
// .get_vocab(true)
// .get("<|endoftext|>")
// .ok_or(E::msg("No pad token"))?;
// The model does not work well if the text is padded using the <|endoftext|> token, using 0
// as the original OpenCLIP code.
let pad_id = 0;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -72,9 +72,8 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let image =
candle_examples::imagenet::load_image(args.image, args.which.resolution() as usize)?
.to_device(&device)?;
let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
.to_device(&device)?;
println!("loaded image {image:?}");
let model_file = match args.model {

View File

@ -284,11 +284,11 @@ impl MusicgenDecoder {
};
let embed_dim = cfg.vocab_size + 1;
let embed_tokens = (0..cfg.num_codebooks)
.map(|i| embedding(embed_dim, h, vb.pp(format!("embed_tokens.{i}"))))
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
.collect::<Result<Vec<_>>>()?;
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
let layers = (0..cfg.num_hidden_layers)
.map(|i| MusicgenDecoderLayer::load(vb.pp(format!("layers.{i}")), cfg))
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
@ -341,7 +341,7 @@ impl MusicgenForCausalLM {
let h = cfg.hidden_size;
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
let lm_heads = (0..cfg.num_codebooks)
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(format!("lm_heads.{i}"))))
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
decoder,

View File

@ -1,28 +0,0 @@
# PaliGemma
[HuggingFace Model Card](https://huggingface.co/google/paligemma-3b-pt-224) -
[Model Page](https://ai.google.dev/gemma/docs/paligemma)
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.267744448s
caption fr. Un groupe de cyclistes qui sont dans la rue.
13 tokens generated (56.52 token/s)
```
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.271492621s
caption fr une image d' un robot sur la plage avec le mot rouillé
15 tokens generated (62.78 token/s)
```

View File

@ -1,276 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::paligemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
self.model.forward(&input)?
} else {
self.model.setup(&self.image, &input)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "google/paligemma-3b-mix-224".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config = Config::paligemma_3b_224();
let image = load_image(&args.image, config.vision_config.image_size)?
.to_device(&device)?
.to_dtype(dtype)?
.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
let prompt = format!("{}\n", args.prompt);
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,23 +0,0 @@
# candle-parler-tts
[Parler-TTS](https://huggingface.co/parler-tts/parler-tts-large-v1) is a large
text-to-speech model with 2.2B parameters trained on ~45K hours of audio data.
The voice can be controlled by a text prompt.
## Run an example
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?"
```
In order to specify some prompt for the voice, use the `--description` argument.
```bash
cargo run --example parler-tts -r -- \
--prompt "Hey, how are you doing today?" \
--description "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
```
https://github.com/user-attachments/assets/1b16aeac-70a3-4803-8589-4563279bba33

View File

@ -1,206 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
#[arg(
long,
default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
)]
description: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.0)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 0)]
seed: u64,
#[arg(long, default_value_t = 5000)]
sample_len: usize,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
/// Use f16 precision for all the computations rather than f32.
#[arg(long)]
f16: bool,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long, default_value_t = 512)]
max_steps: usize,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "large-v1")]
which: Which,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "large-v1")]
LargeV1,
#[value(name = "mini-v1")]
MiniV1,
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::LargeV1 => "parler-tts/parler-tts-large-v1".to_string(),
Which::MiniV1 => "parler-tts/parler-tts-mini-v1".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::MiniV1 => vec![repo.get("model.safetensors")?],
Which::LargeV1 => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
let mut model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let description_tokens = tokenizer
.encode(args.description, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
let prompt_tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
let lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
args.top_p,
);
println!("starting generation...");
let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
println!("generated codes\n{codes}");
let codes = codes.to_dtype(DType::I64)?;
codes.save_safetensors("codes", "out.safetensors")?;
let codes = codes.unsqueeze(0)?;
let pcm = model
.audio_encoder
.decode_codes(&codes.to_device(&device)?)?;
println!("{pcm}");
let pcm = pcm.i((0, 0))?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
Ok(())
}

View File

@ -1,28 +0,0 @@
# pixtral
Pixtral-12B is a 12B text+vision model.
[Blog Post](https://mistral.ai/news/pixtral-12b/) -
[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
```bash
cargo run --profile=release-with-debug --features cuda --example pixtral -- \
--image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
Describe the image.
The image depicts a charming, rustic robot standing on a sandy beach at sunset.
The robot has a vintage, steampunk aesthetic with visible gears and mechanical
parts. It is holding a small lantern in one hand, which emits a warm glow, and
its other arm is extended forward as if reaching out or guiding the way. The
robot's body is adorned with the word "RUST" in bright orange letters, adding to
its rustic theme.
The background features a dramatic sky filled with clouds, illuminated by the
setting sun, casting a golden hue over the scene. Gentle waves lap against the
shore, creating a serene and picturesque atmosphere. The overall mood of the
image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
```

View File

@ -1,327 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::pixtral::{vision_model, Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let get_token = |v| match self.tokenizer.get_token(v) {
Some(token) => Ok(token),
None => anyhow::bail!("cannot find the {v} token"),
};
let bos_token = get_token("<s>")?;
let eos_token = get_token("</s>")?;
let inst_token = get_token("[INST]")?;
let end_inst_token = get_token("[/INST]")?;
let img_break = get_token("[IMG_BREAK]")?;
let img_end = get_token("[IMG_END]")?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let logits = if index > 0 {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
self.model.lm_forward(&input)?
} else {
let (_b, _c, h, w) = self.image.dims4()?;
let h = h / self.model.patch_size;
let w = w / self.model.patch_size;
let image_embeds = self.model.encode_image(&self.image)?;
println!("generated image embeddings {image_embeds:?}");
let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let break_embeds = {
let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let start_embeds = {
let mut in_tokens = vec![bos_token, inst_token];
in_tokens.extend_from_slice(tokens.as_slice());
let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let end_embeds = {
let input =
Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let mut input_embeds = vec![start_embeds];
for h_idx in 0..h {
if h_idx > 0 {
input_embeds.push(break_embeds.clone())
}
let row = image_embeds.narrow(1, h_idx * w, w)?;
input_embeds.push(row);
}
input_embeds.push(end_embeds);
let input_embeds = Tensor::cat(&input_embeds, 1)?;
self.model.lm_forward_embeds(&input_embeds)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Describe the image.\n")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
#[arg(long)]
vision_only: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "mistral-community/pixtral-12b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let device = candle_examples::device(args.cpu)?;
let dtype = if device.supports_bf16() && !args.vision_only {
DType::BF16
} else {
DType::F32
};
let config: Config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let image = if args.image.ends_with(".safetensors") {
match candle::safetensors::load(&args.image, &device)?.remove("img") {
None => anyhow::bail!("no img tensor in {}", args.image),
Some(v) => v,
}
} else {
candle_examples::imagenet::load_image_with_std_mean(
&args.image,
1024,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.261_302_6, 0.275_777_1],
)?
};
let image = image.to_device(&device)?.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.vision_only {
let start = std::time::Instant::now();
let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
println!("loaded the model in {:?}", start.elapsed());
let embs = model.forward(&image)?;
println!("EMBS\n{embs}");
} else {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
}
Ok(())
}

View File

@ -1,24 +0,0 @@
## SigLIP
SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmoid based loss,
[HuggingFace](https://huggingface.co/google/siglip-base-patch16-224).
### Running an example
```
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -1,153 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = config.text_config.max_position_embeddings;
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,12 +0,0 @@
# silero-vad: Voice Activity Detection
[Silero VAD (v5)](https://github.com/snakers4/silero-vad) detects voice activity in streaming audio.
This example uses the models available in the hugging face [onnx-community/silero-vad](https://huggingface.co/onnx-community/silero-vad).
## Running the example
```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```

View File

@ -1,199 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use candle::{DType, Tensor};
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "silero")]
Silero,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum SampleRate {
#[value(name = "8000")]
Sr8k,
#[value(name = "16000")]
Sr16k,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
input: Option<String>,
#[arg(long)]
sample_rate: SampleRate,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The model to use.
#[arg(long, default_value = "silero")]
which: Which,
}
/// an iterator which reads consecutive frames of le i16 values from a reader
struct I16Frames<R> {
rdr: R,
buf: Box<[u8]>,
len: usize,
eof: bool,
}
impl<R> I16Frames<R> {
fn new(rdr: R, frame_size: usize) -> Self {
I16Frames {
rdr,
buf: vec![0; frame_size * std::mem::size_of::<i16>()].into_boxed_slice(),
len: 0,
eof: false,
}
}
}
impl<R: std::io::Read> Iterator for I16Frames<R> {
type Item = std::io::Result<Vec<f32>>;
fn next(&mut self) -> Option<Self::Item> {
if self.eof {
return None;
}
self.len += match self.rdr.read(&mut self.buf[self.len..]) {
Ok(0) => {
self.eof = true;
0
}
Ok(n) => n,
Err(e) => return Some(Err(e)),
};
if self.eof || self.len == self.buf.len() {
let buf = self.buf[..self.len]
.chunks(2)
.map(|bs| match bs {
[a, b] => i16::from_le_bytes([*a, *b]),
_ => unreachable!(),
})
.map(|i| i as f32 / i16::MAX as f32)
.collect();
self.len = 0;
Some(Ok(buf))
} else {
self.next()
}
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let start = std::time::Instant::now();
let model_id = match &args.model_id {
Some(model_id) => std::path::PathBuf::from(model_id),
None => match args.which {
Which::Silero => hf_hub::api::sync::Api::new()?
.model("onnx-community/silero-vad".into())
.get("onnx/model.onnx")?,
// TODO: candle-onnx doesn't support Int8 dtype
// Which::SileroQuantized => hf_hub::api::sync::Api::new()?
// .model("onnx-community/silero-vad".into())
// .get("onnx/model_quantized.onnx")?,
},
};
let (sample_rate, frame_size, context_size): (i64, usize, usize) = match args.sample_rate {
SampleRate::Sr8k => (8000, 256, 32),
SampleRate::Sr16k => (16000, 512, 64),
};
println!("retrieved the files in {:?}", start.elapsed());
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let model = candle_onnx::read_file(model_id)?;
println!("loaded the model in {:?}", start.elapsed());
let start = std::time::Instant::now();
struct State {
frame_size: usize,
sample_rate: Tensor,
state: Tensor,
context: Tensor,
}
let mut state = State {
frame_size,
sample_rate: Tensor::new(sample_rate, &device)?,
state: Tensor::zeros((2, 1, 128), DType::F32, &device)?,
context: Tensor::zeros((1, context_size), DType::F32, &device)?,
};
let mut res = vec![];
for chunk in I16Frames::new(std::io::stdin().lock(), state.frame_size) {
let chunk = chunk.unwrap();
if chunk.len() < state.frame_size {
continue;
}
let next_context = Tensor::from_slice(
&chunk[state.frame_size - context_size..],
(1, context_size),
&device,
)?;
let chunk = Tensor::from_vec(chunk, (1, state.frame_size), &device)?;
let chunk = Tensor::cat(&[&state.context, &chunk], 1)?;
let inputs = std::collections::HashMap::from_iter([
("input".to_string(), chunk),
("sr".to_string(), state.sample_rate.clone()),
("state".to_string(), state.state.clone()),
]);
let out = candle_onnx::simple_eval(&model, inputs).unwrap();
let out_names = &model.graph.as_ref().unwrap().output;
let output = out.get(&out_names[0].name).unwrap().clone();
state.state = out.get(&out_names[1].name).unwrap().clone();
assert_eq!(state.state.dims(), &[2, 1, 128]);
state.context = next_context;
let output = output.flatten_all()?.to_vec1::<f32>()?;
assert_eq!(output.len(), 1);
let output = output[0];
println!("vad chunk prediction: {output}");
res.push(output);
}
println!("calculated prediction in {:?}", start.elapsed());
let res_len = res.len() as f32;
let prediction = res.iter().sum::<f32>() / res_len;
println!("vad average prediction: {prediction}");
Ok(())
}

View File

@ -1,28 +0,0 @@
# candle-splade
SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations benefit from several advantages compared to dense approaches: efficient use of inverted index, explicit lexical match, interpretability... They also seem to be better at generalizing on out-of-domain data. In this example we can do the following two tasks:
- Compute sparse embedding for a given query.
- Compute similarities between a set of sentences using sparse embeddings.
## Sparse Sentence embeddings
SPLADE is used to compute the sparse embedding for a given query. The model weights
are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model.
```bash
cargo run --example splade --release -- --prompt "Here is a test sentence"
> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats"
> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146]
```
```bash
cargo run --example splade --release --features
> score: 0.47 'The new movie is awesome' 'The new movie is so great'
> score: 0.43 'The cat sits outside' 'The cat plays in the garden'
> score: 0.14 'I love pasta' 'Do you like pizza?'
> score: 0.11 'A man is playing guitar' 'The cat plays in the garden'
> score: 0.05 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -1,210 +0,0 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{self, BertForMaskedLM, Config};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "prithivida/Splade_PP_en_v1".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = bert::DTYPE;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() }
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap()
};
let model = BertForMaskedLM::load(vb, &config)?;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, None)?;
let vec = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?
.max(1)?;
let vec = normalize_l2(&vec)?;
let vec = vec.squeeze(0)?.to_vec1::<f32>()?;
let indices = (0..vec.len())
.filter(|&i| vec[i] != 0.0)
.map(|x| x as u32)
.collect::<Vec<_>>();
let tokens = tokenizer.decode(&indices, true).unwrap();
println!("{tokens:?}");
let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>();
println!("{values:?}");
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
let vector = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?;
let vector = vector
.broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)?
.max(1)?;
let vec = normalize_l2(&vector)?;
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = vec.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = vec.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -1,71 +0,0 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5
![](assets/stable-diffusion-3.jpg)
*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium
Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture.
- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
Stable Diffusion 3.5 is a family of text-to-image models with latest improvements:
- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5)
It has three variants:
- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with scaled and slightly modified MMDiT architecture.
- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) distilled version that enables 4-step inference.
- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with improved MMDiT-X architecture.
## Getting access to the weights
The weights of Stable Diffusion 3/3.5 is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on HuggingFace Hub to gain access to the weights for your HuggingFace account.
To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli):
```shell
huggingface-cli login
```
and you will be prompted to enter your token.
On the first run, the weights will be automatically downloaded from the Huggingface Hub. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
## Running the model
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
--which 3-medium --height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```
To use different models, changed the value of `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`).
To display other options available,
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```
If GPU supports, Flash-Attention is a strongly recommended feature as it can greatly improve the speed of inference, as MMDiT is a transformer model heavily depends on attentions. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`.
```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```
## Performance Benchmark
Below benchmark is done with Stable Diffusion 3 Medium by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds).
[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):
- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking.
- RAM: 64G dual-channel DDR5 @ 4800 MT/s
| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti | 0.83 | 2.15 |
| RTX 4090 | 1.72 | 4.06 |

Binary file not shown.

Before

Width:  |  Height:  |  Size: 81 KiB

View File

@ -1,234 +0,0 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use std::path::PathBuf;
use tokenizers::tokenizer::Tokenizer;
struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to serialize huggingface PathBuf of CLIP tokenizer",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let eos_position = tokens.len() - 1;
while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}
struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
Ok(embeddings)
}
}
pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}
impl StableDiffusion3TripleClipWithTokenizer {
pub fn new_split(
clip_g_file: &PathBuf,
clip_l_file: &PathBuf,
t5xxl_file: &PathBuf,
device: &candle::Device,
) -> Result<Self> {
let vb_clip_g = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
};
let vb_clip_l = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
};
let vb_t5 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
};
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_clip_l,
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
let clip_g = ClipWithTokenizer::new(
vb_clip_g,
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
vb.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;
let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;
let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;
let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
Ok((context, y))
}
}

View File

@ -1,273 +0,0 @@
mod clip;
mod sampling;
mod vae;
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
use crate::clip::StableDiffusion3TripleClipWithTokenizer;
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
use anyhow::{Ok, Result};
use clap::Parser;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3-medium")]
V3Medium,
#[value(name = "3.5-large")]
V3_5Large,
#[value(name = "3.5-large-turbo")]
V3_5LargeTurbo,
#[value(name = "3.5-medium")]
V3_5Medium,
}
impl Which {
fn is_3_5(&self) -> bool {
match self {
Self::V3Medium => false,
Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
}
}
}
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A cute rusty robot holding a candle torch in its hand, \
with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
bright background, high quality, 4k"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use flash_attn to accelerate attention operation in the MMDiT.
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
height: usize,
/// The width in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
width: usize,
/// The model to use.
#[arg(long, default_value = "3-medium")]
which: Which,
/// The seed to use when generating random samples.
#[arg(long)]
num_inference_steps: Option<usize>,
/// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// Use Skip Layer Guidance (SLG) for the sampling.
/// Currently only supports Stable Diffusion 3.5 Medium.
#[arg(long)]
use_slg: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
tracing,
use_flash_attn,
height,
width,
num_inference_steps,
cfg_scale,
time_shift,
seed,
which,
use_slg,
} = Args::parse();
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(cpu)?;
let default_inference_steps = match which {
Which::V3_5Large => 28,
Which::V3_5LargeTurbo => 4,
Which::V3_5Medium => 28,
Which::V3Medium => 28,
};
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
let default_cfg_scale = match which {
Which::V3_5Large => 4.0,
Which::V3_5LargeTurbo => 1.0,
Which::V3_5Medium => 4.0,
Which::V3Medium => 4.0,
};
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
let api = hf_hub::api::sync::Api::new()?;
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
let sai_repo_for_text_encoders = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
// Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that's usually
// placed under the text_encoders directory, like the case in stabilityai/stable-diffusion-3.5-large and -large-turbo.
// To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors
// under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we need to merge the two partitions
// to get the monolithic text encoders. This is not a trivial task.
// Since the situation can change, we do not want to spend efforts to handle the uniqueness of stabilityai/stable-diffusion-3.5-medium,
// which involves different paths and merging the two partitions files for t5xxl_fp16.safetensors.
// so for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository.
// TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders.
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let sai_repo_for_mmdit = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
let model_file = {
let model_file = match which {
Which::V3_5Large => "sd3.5_large.safetensors",
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
Which::V3_5Medium => "sd3.5_medium.safetensors",
Which::V3Medium => unreachable!(),
};
sai_repo_for_mmdit.get(model_file)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
&clip_g_file,
&clip_l_file,
&t5xxl_file,
&device,
)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
};
match which {
Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
Which::V3Medium => unreachable!(),
}
} else {
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?;
(MMDiTConfig::sd3_medium(), triple, vb)
};
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
// Drop the text model early to avoid using too much memory.
drop(triple);
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let slg_config = if use_slg {
match which {
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
scale: 2.5,
start: 0.01,
end: 0.2,
layers: vec![7, 8, 9],
}),
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
}
} else {
None
};
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;
sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);
let img = {
let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
// Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
};
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}

View File

@ -1,83 +0,0 @@
use anyhow::{Ok, Result};
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
use candle_transformers::models::mmdit::model::MMDiT;
pub struct SkipLayerGuidanceConfig {
pub scale: f64,
pub start: f64,
pub end: f64,
pub layers: Vec<usize>,
}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
mmdit: &MMDiT,
y: &Tensor,
context: &Tensor,
num_inference_steps: usize,
cfg_scale: f64,
time_shift: f64,
height: usize,
width: usize,
slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
.map(|x| x as f64 / num_inference_steps as f64)
.rev()
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
};
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
None,
)?;
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
if let Some(slg_config) = slg_config.as_ref() {
if (num_inference_steps as f64) * slg_config.start < (step as f64)
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
{
let slg_noise_pred = mmdit.forward(
&x,
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
&y.i(..1)?,
&context.i(..1)?,
Some(&slg_config.layers),
)?;
guidance = (guidance
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
}
}
x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}
// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper
// https://arxiv.org/pdf/2403.03206
// Following the implementation in ComfyUI:
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/
// comfy/model_sampling.py#L181
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
alpha * t / (1.0 + (alpha - 1.0) * t)
}
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
- ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
}

View File

@ -1,93 +0,0 @@
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
let config = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 16,
norm_num_groups: 32,
use_quant_conv: false,
use_post_quant_conv: false,
};
Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
}
pub fn sd3_vae_vb_rename(name: &str) -> String {
let parts: Vec<&str> = name.split('.').collect();
let mut result = Vec::new();
let mut i = 0;
while i < parts.len() {
match parts[i] {
"down_blocks" => {
result.push("down");
}
"mid_block" => {
result.push("mid");
}
"up_blocks" => {
result.push("up");
match parts[i + 1] {
// Reverse the order of up_blocks.
"0" => result.push("3"),
"1" => result.push("2"),
"2" => result.push("1"),
"3" => result.push("0"),
_ => {}
}
i += 1; // Skip the number after up_blocks.
}
"resnets" => {
if i > 0 && parts[i - 1] == "mid_block" {
match parts[i + 1] {
"0" => result.push("block_1"),
"1" => result.push("block_2"),
_ => {}
}
i += 1; // Skip the number after resnets.
} else {
result.push("block");
}
}
"downsamplers" => {
result.push("downsample");
i += 1; // Skip the 0 after downsamplers.
}
"conv_shortcut" => {
result.push("nin_shortcut");
}
"attentions" => {
if parts[i + 1] == "0" {
result.push("attn_1")
}
i += 1; // Skip the number after attentions.
}
"group_norm" => {
result.push("norm");
}
"query" => {
result.push("q");
}
"key" => {
result.push("k");
}
"value" => {
result.push("v");
}
"proj_attn" => {
result.push("proj_out");
}
"conv_norm_out" => {
result.push("norm_out");
}
"upsamplers" => {
result.push("upsample");
i += 1; // Skip the 0 after upsamplers.
}
part => result.push(part),
}
i += 1;
}
result.join(".")
}

View File

@ -1,45 +0,0 @@
# candle-stella-en-v5: Implementation of [stella_en_1.5B_v5](https://huggingface.co/dunzhang/stella_en_1.5B_v5) embedding model
As of 7th Oct 2024, *Stella_en_1.5B_v5* is one of the top ranking model on `retrieval` and `reranking` tasks in [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard.
[Model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) on the HuggingFace Hub.
## Running the example
Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]
```
Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling multiple embedding dimensions.
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda>
>
> Score: 0.8178786
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
> Score: 0.7853528
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```
## Supported options:
- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.

View File

@ -1,359 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::path::Path;
use anyhow::{anyhow, Error as E, Result};
use clap::Parser;
use candle_transformers::models::stella_en_v5::{
Config, EmbedDim as StellaEmbedDim, EmbeddingModel,
};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo};
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
struct Embedding {
model: EmbeddingModel,
device: Device,
tokenizer: Tokenizer,
}
impl Embedding {
fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self {
Self {
model,
tokenizer,
device: device.clone(),
}
}
fn encode(&mut self, task: EncodeTask, text: Option<String>) -> Result<()> {
// Just shocasing embeddings, this has no real value
if let Some(text) = text {
let qry = task.query_preproc(&[text]);
let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?;
let shape = (1, encoding.len());
let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?;
let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?;
let result = self.model.forward(&input, &mask)?;
println!("embeddings: {result}");
} else {
// Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers)
let queries = [
"What are some ways to reduce stress?".to_string(),
"What are the benefits of drinking green tea?".to_string(),
];
let docs = [
"There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(),
"Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(),
];
// We only encode the queries and not the data
let qry = task.query_preproc(&queries);
let mut qry_encoded = self
.tokenizer
.encode_batch(qry, true)
.map_err(|e| anyhow!(e))?;
let mut docs_encoded = self
.tokenizer
.encode_batch(docs.to_vec(), true)
.map_err(|e| anyhow!(e))?;
let qry_embed = {
// Now, we generate the tensors for the `input` and `mask`
let shape = (qry_encoded.len(), qry_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in qry_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
// Let's generate the embeddings for the query, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
let doc_embed = {
let shape = (docs_encoded.len(), docs_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in docs_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
// Let's generate the embeddings for the query, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
println!(
"Embed shapes:\nQuery: {:?}\nDocs: {:?}",
qry_embed.shape(),
doc_embed.shape()
); // [2, 1024] for head dim `1024`
// a matmul to generate the `similarity` score
let res = qry_embed.matmul(&doc_embed.t()?)?;
for (k, v) in queries.iter().enumerate() {
let tnsr = res.get(k)?;
let max = tnsr.argmax(0)?.to_scalar::<u32>()?;
println!(
"\nScore: {}\nQuery: {}\nAnswer: {}\n\n",
tnsr.get(max as usize)?.to_scalar::<f32>()?,
v,
docs[k]
);
}
}
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum EmbedDim {
#[value(name = "256")]
Dim256,
#[value(name = "768")]
Dim768,
#[value(name = "1024")]
Dim1024,
#[value(name = "2048")]
Dim2048,
#[value(name = "4096")]
Dim4096,
#[value(name = "6144")]
Dim6144,
#[value(name = "8192")]
Dim8192,
}
impl EmbedDim {
/// Returns dir path to the embed head weights int he repo
pub fn embed_dim_default_dir(&self) -> &'static str {
match self {
Self::Dim256 => "2_Dense_256",
Self::Dim768 => "2_Dense_768",
Self::Dim1024 => "2_Dense_1024",
Self::Dim2048 => "2_Dense_2048",
Self::Dim4096 => "2_Dense_4096",
Self::Dim6144 => "2_Dense_6144",
Self::Dim8192 => "2_Dense_8192",
}
}
/// Resolves the `EmbedDim` for given variant
pub fn embed_dim(&self) -> StellaEmbedDim {
match self {
Self::Dim256 => StellaEmbedDim::Dim256,
Self::Dim768 => StellaEmbedDim::Dim768,
Self::Dim1024 => StellaEmbedDim::Dim1024,
Self::Dim2048 => StellaEmbedDim::Dim2048,
Self::Dim4096 => StellaEmbedDim::Dim4096,
Self::Dim6144 => StellaEmbedDim::Dim6144,
Self::Dim8192 => StellaEmbedDim::Dim8192,
}
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
pub enum EncodeTask {
/// `s2p` is the `retrieval` task
/// Default in this example
#[value(name = "s2p")]
S2P,
/// `s2s` is the semantic similarity task
#[value(name = "s2s")]
S2S,
}
impl EncodeTask {
/// Preprocess a set of inputs basef on a template suggested by the model authors
/// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction
pub fn query_preproc(&self, txt: &[String]) -> Vec<String> {
let instruct = match self {
Self::S2P => {
"Given a web search query, retrieve relevant passages that answer the query."
}
Self::S2S => "Retrieve semantically similar text.",
};
txt.iter()
.map(|s| format!("Instruct: {instruct}\nQuery: {s}"))
.collect::<Vec<_>>()
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
query: Option<String>,
#[arg(long, default_value = "1024")]
embed_dim: Option<EmbedDim>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
base_weight_files: Option<String>,
#[arg(long)]
embed_head_weight_files: Option<String>,
/// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
/// `s2s`: Semantic textual similarity
/// `s2p`: Retrieval task - `Default` in this example
#[arg(long, default_value = "s2p")]
task: Option<EncodeTask>,
}
// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
Ok(tokenizer)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let start = std::time::Instant::now();
let api = Api::new()?;
let embed_dim = match args.embed_dim {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
// Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights
// E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo
let base_weight_files = match args.base_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
let embed_weight_files = match args.embed_head_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir());
vec![repo.get(&head_w_path)?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = DType::F32;
let base_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? };
// Embedding layer is always built on F32 for accuracy
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
let model = EmbeddingModel::new(
&Config::new_1_5_b_v5(embed_dim.embed_dim()),
base_vb,
embed_vb,
)?;
println!("loaded the model in {:?}", start.elapsed());
let mut embedding = Embedding::new(model, tokenizer, &device);
let task = args.task.map_or(EncodeTask::S2P, |t| t);
embedding.encode(task, args.query)
}

View File

@ -10,6 +10,7 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
@ -17,6 +18,7 @@ mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
@ -389,7 +391,6 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -406,7 +407,6 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
@ -427,7 +427,6 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
@ -480,10 +479,6 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
/// The input device to use.
#[arg(long)]
device: Option<String>,
}
pub fn main() -> Result<()> {
@ -548,12 +543,13 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let mut decoder = Decoder::new(
let language_token = None;
let mut dc = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
/* language_token */ None,
language_token,
args.task,
args.timestamps,
args.verbose,
@ -569,69 +565,47 @@ pub fn main() -> Result<()> {
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let audio_device = match args.device.as_ref() {
None => host.default_input_device(),
Some(device) => host
.input_devices()?
.find(|x| x.name().map_or(false, |y| &y == device)),
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
}
.expect("failed to find the audio input device");
.expect("failed to find input device");
let audio_config = audio_device
let _config = _device
.default_input_config()
.expect("Failed to get default input config");
println!("audio config {audio_config:?}");
let channel_count = audio_config.channels() as usize;
let in_sample_rate = audio_config.sample_rate().0 as usize;
let resample_ratio = 16000. / in_sample_rate as f64;
let mut resampler = rubato::FastFixedIn::new(
resample_ratio,
10.,
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let (tx, rx) = std::sync::mpsc::channel();
let stream = audio_device.build_input_stream(
&audio_config.config(),
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
let pcm = pcm
.iter()
.step_by(channel_count)
.copied()
.collect::<Vec<f32>>();
if !pcm.is_empty() {
tx.send(pcm).unwrap()
}
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
let channel_count = _config.channels() as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();
std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
// loop to process the audio data forever (until the user stops the program)
println!("transcribing audio...");
let mut buffered_pcm = vec![];
let mut language_token_set = false;
while let Ok(pcm) = rx.recv() {
use rubato::Resampler;
buffered_pcm.extend_from_slice(&pcm);
if buffered_pcm.len() < 10 * in_sample_rate {
continue;
}
let mut resampled_pcm = vec![];
for buffered_pcm in buffered_pcm.chunks(1024) {
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0])
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
buffered_pcm.clear();
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
@ -640,13 +614,9 @@ pub fn main() -> Result<()> {
)?;
// on the first iteration, we detect the language and set the language token.
if !language_token_set {
if i == 0 {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(
decoder.model(),
&tokenizer,
&mel,
)?),
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
@ -657,12 +627,47 @@ pub fn main() -> Result<()> {
}
};
println!("language_token: {:?}", language_token);
decoder.set_language_token(language_token);
language_token_set = true;
dc.set_language_token(language_token);
}
decoder.run(&mel, None)?;
decoder.reset_kv_cache();
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
}
Ok(())
}
fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

View File

@ -12,7 +12,7 @@ file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/sample
from the hub.
```bash
cargo run --example whisper --release --features="symphonia"
cargo run --example whisper --release
> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }

View File

@ -370,7 +370,6 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -389,7 +388,6 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2
| Self::DistilLargeV3 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
@ -411,7 +409,6 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),

View File

@ -123,7 +123,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
let (bn, bias) = match b.parameters.get("batch_normalize") {
Some(p) if p.parse::<usize>()? != 0 => {
let bn = batch_norm(filters, 1e-5, vb.pp(format!("batch_norm_{index}")))?;
let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
(Some(bn), false)
}
Some(_) | None => (None, true),
@ -135,9 +135,9 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
dilation: 1,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
} else {
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
};
let leaky = match activation {
"leaky" => true,

View File

@ -161,7 +161,7 @@ impl C2f {
let cv2 = ConvBlock::load(vb.pp("cv2"), (2 + n) * c, c2, 1, 1, None)?;
let mut bottleneck = Vec::with_capacity(n);
for idx in 0..n {
let b = Bottleneck::load(vb.pp(format!("bottleneck.{idx}")), c, c, shortcut)?;
let b = Bottleneck::load(vb.pp(&format!("bottleneck.{idx}")), c, c, shortcut)?;
bottleneck.push(b)
}
Ok(Self {

View File

@ -1,42 +1,23 @@
use candle::{Device, Result, Tensor};
pub const IMAGENET_MEAN: [f32; 3] = [0.485f32, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
/// Loads an image from disk using the image crate at the requested resolution,
/// using the given std and mean parameters.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
p: P,
res: usize,
mean: &[f32; 3],
std: &[f32; 3],
) -> Result<Tensor> {
/// Loads an image from disk using the image crate at the requested resolution.
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
let img = image::ImageReader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(
res as u32,
res as u32,
image::imageops::FilterType::Triangle,
);
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (res, res, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean = Tensor::new(mean, &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(std, &Device::Cpu)?.reshape((3, 1, 1))?;
let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
/// Loads an image from disk using the image crate at the requested resolution.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: usize) -> Result<Tensor> {
load_image_with_std_mean(p, res, &IMAGENET_MEAN, &IMAGENET_STD)
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 224, 224). imagenet normalization is applied.
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.7.2"
version = "0.6.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.2" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.7.2"
version = "0.6.0"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -70,9 +70,10 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
@ -133,9 +134,10 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float tmp = 0.0f; // partial sum for thread in warp
@ -528,15 +530,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const int n_cols, const int block_size, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, block_size, eps); \
const int n_cols, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, block_size, eps); \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.7.2"
version = "0.6.0"
edition = "2021"
description = "Metal kernels for Candle"
@ -17,12 +17,9 @@ thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
clap = { version = "4.2.4", features = ["derive"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
anyhow = "1"
rand = "0.8.5"
rand_distr = "0.4.3"

View File

@ -1,136 +0,0 @@
use anyhow::Result;
use candle_metal_kernels::GemmDType;
/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
use clap::{Parser, Subcommand};
use half::f16;
fn run_gemm(f32: bool, n: usize) -> Result<()> {
const WARMUP_ITERS: usize = 2;
const MIN_DUR: f64 = 4.;
let device = metal::Device::system_default().unwrap();
let (b, m, n, k) = (1, n, n, n);
let kernels = candle_metal_kernels::Kernels::new();
let command_queue = device.new_command_queue();
let options = metal::MTLResourceOptions::StorageModeManaged;
let (lhs, rhs) = if f32 {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
} else {
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(&rhs) as u64,
options,
);
(lhs, rhs)
};
let (dtype, name, sizeof) = if f32 {
(GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
} else {
(GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
};
let output = device.new_buffer((b * m * n * sizeof) as u64, options);
for mlx in [false, true] {
let mut sum_dt = 0f64;
let mut iters = 0usize;
for idx in 0.. {
let command_buffer = command_queue.new_command_buffer();
let start_time = std::time::Instant::now();
if mlx {
candle_metal_kernels::call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
} else {
candle_metal_kernels::call_gemm(
&device,
command_buffer,
&kernels,
name,
(b, m, n, k),
&[m * k, k, 1],
0,
&lhs,
&[n * k, n, 1],
0,
&rhs,
&output,
)?;
}
command_buffer.commit();
command_buffer.wait_until_completed();
let dt = start_time.elapsed().as_secs_f64();
if idx < WARMUP_ITERS {
continue;
}
sum_dt += dt;
iters += 1;
if sum_dt > MIN_DUR {
break;
}
}
let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
let mlx = if mlx { "MLX" } else { "MFA" };
println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
}
Ok(())
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Gemm,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// The benchmark to be run.
#[command(subcommand)]
task: Task,
}
fn main() -> Result<()> {
let args = Args::parse();
match args.task {
Task::Gemm => {
for f32 in [false, true] {
for n in [512, 1024, 2048, 4096] {
run_gemm(f32, n)?;
}
}
}
}
Ok(())
}

View File

@ -105,7 +105,7 @@ kernel void FN_NAME##_strided( \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \

View File

@ -1,39 +0,0 @@
#include <metal_stdlib>
using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \
fill_with<T>(out, value, numel, tid); \
} \
#define FILL_OPS(NAME, T) \
FILL_OP(NAME, T) \
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)
FILL_OPS(f16, half)
FILL_OPS(f32, float)
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif

View File

@ -0,0 +1,921 @@
//
// GEMM.metal
// MetalFlashAttention
//
// Created by Philip Turner on 6/23/23.
//
#include <metal_stdlib>
#ifndef __METAL_SIMDGROUP_EVENT
#define __METAL_SIMDGROUP_EVENT
// Invoking the generation of LLVM bitcode for async copies.
//
// %struct._simdgroup_event_t = type opaque
//
struct _simdgroup_event_t;
// Invoking the generation of LLVM bitcode for async copies.
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, threadgroup void *, const device void *, ulong)
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_1d(
ulong, ulong, device void *, const threadgroup void *, ulong)
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
// i64, i64,
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
// <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong,
threadgroup void *, ulong, ulong, ulong2,
const device void *, ulong, ulong, ulong2,
long2, int)
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: argmemonly convergent nounwind
// declare %struct._simdgroup_event_t*
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
// i64, i64,
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
// <2 x i64>, i32)
// local_unnamed_addr #4
//
thread _simdgroup_event_t*
__metal_simdgroup_async_copy_2d(
ulong, ulong,
device void *, ulong, ulong, ulong2,
const threadgroup void *, ulong, ulong, ulong2,
long2, int)
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
// Invoking the generation of LLVM bitcode for async copies.
//
// ; Function Attrs: convergent nounwind
// declare void
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
// local_unnamed_addr #3
//
void __metal_wait_simdgroup_events(
int, thread _simdgroup_event_t**)
__asm("air.wait_simdgroup_events");
#pragma METAL internals : enable
namespace metal
{
enum class simdgroup_async_copy_clamp_mode {
clamp_to_zero = 0,
clamp_to_edge = 1
};
struct simdgroup_event {
METAL_FUNC simdgroup_event() thread {}
template <typename T>
METAL_FUNC void async_copy(
threadgroup T *dst,
const device T *src,
ulong n_elements
) thread {
event = __metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the arguments.
reinterpret_cast<threadgroup void *>(dst),
reinterpret_cast<const device void *>(src),
n_elements);
}
template <typename T>
METAL_FUNC void async_copy(
device T *dst,
const threadgroup T *src,
ulong n_elements
) thread {
event = __metal_simdgroup_async_copy_1d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the arguments.
reinterpret_cast<device void *>(dst),
reinterpret_cast<const threadgroup void *>(src),
n_elements);
}
template <typename T>
METAL_FUNC void async_copy(
// Description of the destination.
threadgroup T *dst,
ushort dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const device T *src,
uint src_elements_per_row,
ushort2 src_tile_dimensions,
// Other arguments.
bool transpose_matrix = false,
simdgroup_async_copy_clamp_mode clamp_mode =
simdgroup_async_copy_clamp_mode::clamp_to_zero
) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<threadgroup void *>(dst),
ushort(dst_elements_per_row),
1,
ulong2(dst_tile_dimensions),
// Description of the source.
reinterpret_cast<const device void *>(src),
uint(src_elements_per_row),
1,
ulong2(src_tile_dimensions),
// Other arguments.
long2(0),
static_cast<int>(clamp_mode));
}
template <typename T>
METAL_FUNC void async_copy(
// Description of the destination.
device T *dst,
uint dst_elements_per_row,
ushort2 dst_tile_dimensions,
// Description of the source.
const threadgroup T *src,
ushort src_elements_per_row,
ushort2 src_tile_dimensions,
// Other arguments.
bool transpose_matrix = false
) thread {
if (transpose_matrix) {
src_tile_dimensions = src_tile_dimensions.yx;
dst_tile_dimensions = dst_tile_dimensions.yx;
}
event = __metal_simdgroup_async_copy_2d(
// Description of the data type.
sizeof(T),
alignof(T),
// Description of the destination.
reinterpret_cast<device void *>(dst),
uint(dst_elements_per_row),
1,
ulong2(dst_tile_dimensions),
// Description of the source.
reinterpret_cast<const threadgroup void *>(src),
ushort(src_elements_per_row),
1,
ulong2(src_tile_dimensions),
// Other arguments.
long2(0),
0);
}
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
__metal_wait_simdgroup_events(
count, reinterpret_cast<thread _simdgroup_event_t**>(events));
}
private:
// Invoking the generation of LLVM bitcode for async copies.
//
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
//
thread _simdgroup_event_t* event;
};
} // namespace metal
#pragma METAL internals : disable
#endif
// -*- Metal -*-
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
// Copyright (c) 2023 Philip Turner. See MIT LICENSE
//===----------------------------------------------------------------------===//
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
#define __METAL_SIMDGROUP_MATRIX_STORAGE
// Contains C++ symbols accessible to a developer through automatic code
// completion in Xcode 14.2. Formatted with the same style as the Metal Standard
// Library for consistency with other Metal code.
#if defined(__HAVE_SIMDGROUP_MATRIX__)
#pragma METAL internals : enable
namespace metal
{
template <typename T>
struct simdgroup_matrix_storage {
typedef vec<T, 64> storage_type;
storage_type t;
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
return reinterpret_cast<thread vec<T, 2>*>(&t);
}
METAL_FUNC simdgroup_matrix_storage() thread = default;
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
*(this->thread_elements()) = thread_elements;
}
METAL_FUNC static ushort2 offset(ushort thread_index_in_simdgroup) {
// https://patents.google.com/patent/US11256518B2
ushort lane_id = thread_index_in_simdgroup;
ushort quad_id = lane_id / 4;
constexpr ushort QUADRANT_SPAN_M = 4;
constexpr ushort THREADS_PER_QUADRANT = 8;
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
return ushort2(N_in_simd, M_in_simd);
}
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
} else {
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
}
}
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
} else {
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
}
}
// WARNING: All load and store functions assume the X dimension is divisible by 2.
METAL_FUNC void load(const device T *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y], src[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const device vec<T, 2>*>(src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x);
}
}
METAL_FUNC void load(const threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
*(thread_elements()) = vec<T, 2>(src[matrix_origin.x * elements_per_row + matrix_origin.y], src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y]);
} else {
*(thread_elements()) = *reinterpret_cast<const threadgroup vec<T, 2>*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
}
}
METAL_FUNC void load_first(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][0] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][0] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void load_second(const device T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements()[0][1] = src[matrix_origin.x * elements_per_row + matrix_origin.y];
} else {
thread_elements()[0][1] = src[matrix_origin.y * elements_per_row + matrix_origin.x];
}
}
METAL_FUNC void store(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
dst[ulong((matrix_origin.x + 1) * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<device vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
METAL_FUNC void store_first(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][0];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][0];
}
}
METAL_FUNC void store_second(device T *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[ulong(matrix_origin.x * elements_per_row) + matrix_origin.y] = thread_elements()[0][1];
} else {
dst[matrix_origin.y * elements_per_row + matrix_origin.x] = thread_elements()[0][1];
}
}
METAL_FUNC void store(threadgroup T *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = thread_elements()[0][0];
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = thread_elements()[0][1];
} else {
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements());
}
}
template <typename U, typename V>
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
if (!accumulate) {
*(thread_elements()) = vec<T, 2>(0);
}
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
}
// 'bfloat' is 'float' with the lower 16 bits set to garbage (BF15).
METAL_FUNC thread ushort4* thread_elements_bfloat() thread {
thread float2* elements = thread_elements();
return reinterpret_cast<thread ushort4*>(elements);
}
METAL_FUNC simdgroup_matrix_storage<float> unpack_bfloat() thread {
ushort4 output;
thread ushort2& elements = thread_elements();
output.y = elements[0];
output.w = elements[1];
return simdgroup_matrix_storage(as_type<float2>(output));
}
METAL_FUNC simdgroup_matrix_storage<ushort> pack_bfloat() thread {
thread ushort4* elements = thread_elements_bfloat();
return simdgroup_matrix_storage(ushort2(elements->y, elements->w));
}
METAL_FUNC void load_bfloat(const threadgroup ushort *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
thread_elements_bfloat()->y = src[matrix_origin.x * elements_per_row + matrix_origin.y];
thread_elements_bfloat()->w = src[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y];
} else {
thread_elements_bfloat()->zw = *reinterpret_cast<const threadgroup ushort2*>(src + matrix_origin.y * elements_per_row + matrix_origin.x);
thread_elements_bfloat()->y = thread_elements_bfloat()->z;
}
}
METAL_FUNC void store_bfloat(threadgroup ushort *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
if (transpose_matrix) {
dst[matrix_origin.x * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).y;
dst[(matrix_origin.x + 1) * elements_per_row + matrix_origin.y] = *(thread_elements_bfloat()).w;
} else {
*(thread_elements_bfloat()).z = *(thread_elements_bfloat()).y;
*reinterpret_cast<threadgroup vec<T, 2>*>(dst + matrix_origin.y * elements_per_row + matrix_origin.x) = *(thread_elements_bfloat()).zw;
}
}
};
} // namespace metal
#pragma METAL internals : disable
#endif
#endif // __METAL_SIMDGROUP_MATRIX_STORAGE
using namespace metal;
// MARK: - Function Constants
// Dimensions of each matrix.
constant uint M [[function_constant(0)]];
constant uint N [[function_constant(1)]];
constant uint K [[function_constant(2)]];
// Whether each matrix is transposed.
constant bool A_trans [[function_constant(10)]];
constant bool B_trans [[function_constant(11)]];
constant bool D_trans [[function_constant(13)]];
constant uint A_leading_dim = A_trans ? M : K;
constant uint B_leading_dim = B_trans ? K : N;
// Alpha and beta constants from BLAS.
constant float alpha [[function_constant(20)]];
constant float beta [[function_constant(21)]];
constant bool batched [[function_constant(100)]];
constant bool fused_activation [[function_constant(101)]];
constant bool fused_bias [[function_constant(50001)]]; // 102
constant bool use_bias = is_function_constant_defined(fused_bias) ? fused_bias : false;
constant bool use_activation_function = fused_activation && !fused_bias;
constant bool use_activation = use_bias || use_activation_function;
constant bool batched_activation_function = batched && use_activation_function;
constant ushort M_simd [[function_constant(200)]];
constant ushort N_simd [[function_constant(201)]];
constant ushort K_simd [[function_constant(202)]];
// Elide work on the edge when matrix dimension < SRAM block dimension.
constant ushort M_modulo = (M % M_simd == 0) ? M_simd : (M % M_simd);
constant ushort N_modulo = (N % N_simd == 0) ? N_simd : (N % N_simd);
constant ushort M_padded = (M < M_simd) ? (M_modulo + 7) / 8 * 8 : M_simd;
constant ushort N_padded = (N < N_simd) ? (N_modulo + 7) / 8 * 8 : N_simd;
constant ushort M_splits [[function_constant(210)]];
constant ushort N_splits [[function_constant(211)]];
constant ushort M_group = M_simd * M_splits;
constant ushort N_group = N_simd * N_splits;
constant ushort A_block_leading_dim = (A_trans ? M_group : K_simd);
constant ushort B_block_leading_dim = (B_trans ? K_simd : N_group);
// There is no padding for M reads/writes.
// There is no padding for N reads/writes.
constant ushort K_simd_unpadded = (K % K_simd == 0) ? K_simd : (K % K_simd);
constant ushort K_simd_padded = (K_simd_unpadded + 7) / 8 * 8;
constant ushort A_sram_length = (M_simd / 8) * 1;
constant ushort B_sram_length = 1 * (N_simd / 8);
constant ushort A_block_length = M_group * K_simd;
// Threadgroup block must fit entire C accumulator and partial sums.
constant ushort A_sram_offset = 0;
constant ushort B_sram_offset = A_sram_offset + A_sram_length;
constant ushort C_sram_offset = B_sram_offset + B_sram_length;
constant ushort A_block_offset = 0;
constant ushort B_block_offset = A_block_offset + A_block_length;
// MARK: - Utilities
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* A_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// A_sram[M_simd][8]
return sram + A_sram_offset + (matrix_origin.y / 8) * (8 / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* B_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// A_sram[8][N_simd]
return sram + B_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC thread simdgroup_matrix_storage<T>* C_sram(thread simdgroup_matrix_storage<T> *sram, ushort2 matrix_origin) {
// C_sram[M_simd][N_simd]
return sram + C_sram_offset + (matrix_origin.y / 8) * (N_simd / 8) + (matrix_origin.x / 8);
}
template <typename T>
METAL_FUNC void prefetch(threadgroup T *A_block, device T *A,
ushort2 A_tile_src, uint2 A_offset,
threadgroup T *B_block, device T *B,
ushort2 B_tile_src, uint2 B_offset, uint k)
{
A_tile_src.x = min(uint(K_simd), K - k);
B_tile_src.y = min(uint(K_simd), K - k);
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
auto B_src = simdgroup_matrix_storage<T>::apply_offset(B, B_leading_dim, B_offset, B_trans);
// Rounded-up ceiling for the threadgroup block.
const uint K_edge_floor = K - K_simd_unpadded;
const uint K_edge_ceil = K_edge_floor + K_simd_padded;
ushort K_padded;
if (K_edge_floor == K_simd) {
K_padded = K_simd;
} else {
K_padded = min(uint(K_simd), K_edge_ceil - k);
}
ushort2 A_tile_dst(K_padded, A_tile_src.y);
ushort2 B_tile_dst(B_tile_src.x, K_padded);
simdgroup_event events[2];
events[0].async_copy(A_block, A_block_leading_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
events[1].async_copy(B_block, B_block_leading_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
simdgroup_event::wait(2, events);
}
// One iteration of the MACC loop, effectively k=8 iterations.
template <typename T>
METAL_FUNC void multiply_accumulate(thread simdgroup_matrix_storage<T> *sram,
const threadgroup T *A_block,
const threadgroup T *B_block,
bool accumulate = true)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
ushort2 origin(0, m);
A_sram(sram, origin)->load(A_block, A_block_leading_dim, origin, A_trans);
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, 0);
B_sram(sram, origin)->load(B_block, B_block_leading_dim, origin, B_trans);
}
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto A = A_sram(sram, ushort2(0, m));
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, ushort2(n, m));
C->multiply(*A, *B, accumulate);
}
}
}
template <typename T>
METAL_FUNC void partial_store(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
if (is_k_summation) {
C_sram(sram, origin)->store(C_block, N_simd, origin);
} else {
C_sram(sram, origin)->store(C_block, N_group, origin);
}
}
}
}
template <typename T>
METAL_FUNC void partial_accumulate(thread simdgroup_matrix_storage<T> *sram,
threadgroup T *C_block, bool is_k_summation)
{
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
if (is_k_summation) {
B->load(C_block, N_simd, origin);
} else {
B->load(C_block, N_group, origin);
}
}
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
ushort2 origin(n, m);
auto B = B_sram(sram, ushort2(n, 0));
auto C = C_sram(sram, origin);
if (is_k_summation) {
C->thread_elements()[0] += B->thread_elements()[0];
} else {
float2 C_old = float2(B->thread_elements()[0]);
float2 C_new = float2(C->thread_elements()[0]);
C->thread_elements()[0] = vec<T, 2>(fast::fma(C_old, beta, C_new));
}
}
}
}
template <typename T>
METAL_FUNC void async_access_accumulator(threadgroup T *C_block, device T *C,
uint2 C_offset, bool is_store)
{
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
min(uint(M_group), M - C_offset.y));
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, C_offset);
if (is_store) {
simdgroup_event event;
event.async_copy(C_src, N, C_tile, C_block, N_group, C_tile);
simdgroup_event::wait(1, &event);
} else {
simdgroup_event event;
event.async_copy(C_block, N_group, C_tile, C_src, N, C_tile);
simdgroup_event::wait(1, &event);
}
}
template <typename T>
METAL_FUNC void store_accumulator(thread simdgroup_matrix_storage<T> *sram,
device T *C, bool m_is_edge, bool n_is_edge)
{
const ushort m_start = (m_is_edge) ? M_modulo : 0;
const ushort n_start = (n_is_edge) ? N_modulo : 0;
const ushort m_end = (m_is_edge) ? M_simd : M_modulo;
const ushort n_end = (n_is_edge) ? N_simd : N_modulo;
#pragma clang loop unroll(full)
for (ushort m = m_start; m < m_end; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = n_start; n < n_end; n += 8) {
ushort2 origin(n, m);
C_sram(sram, origin)->store(C, N, origin);
}
}
}
template <typename T>
struct activation_functor {
using function = void(threadgroup T *C,
device void *D,
uint grid_index_in_batch,
uint2 matrix_origin,
ushort2 tile_dimensions,
ushort lane_id);
typedef visible_function_table<function> function_table;
};
// MARK: - Kernels
template <typename T>
void _gemm_impl(device T *A [[buffer(0)]],
device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup T *threadgroup_block [[threadgroup(0)]],
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<T>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
if (batched) {
// TODO: Re-compute every inner loop iteration for FP64 accumulate.
ulong3 offsets = matrix_offsets[0].xyz * gid.z;
A = (device T*)((device uchar*)A + offsets[0]);
B = (device T*)((device uchar*)B + offsets[1]);
C = (device T*)((device uchar*)C + offsets[2]);
}
simdgroup_matrix_storage<T> sram[1024];
auto A_block = threadgroup_block + A_block_offset;
auto B_block = threadgroup_block + B_block_offset;
ushort2 sid(sidx % N_splits, sidx / N_splits);
ushort2 offset_in_simd = simdgroup_matrix_storage<T>::offset(lane_id);
uint2 A_offset(0, gid.y * M_group);
uint2 B_offset(gid.x * N_group, 0);
{
uint C_base_offset_x = B_offset.x + sid.x * N_simd;
uint C_base_offset_y = A_offset.y + sid.y * M_simd;
if (C_base_offset_x >= N || C_base_offset_y >= M) {
return;
}
}
ushort2 offset_in_group(sid.x * N_simd + offset_in_simd.x,
sid.y * M_simd + offset_in_simd.y);
if (use_bias) {
if (sidx == 0) {
auto bias = (device T*)D;
if (batched) {
ulong offset = matrix_offsets[gid.z].w;
bias = (device T*)((device uchar*)bias + offset);
}
ushort bias_elements;
if (is_function_constant_defined(D_trans) && D_trans) {
bias += A_offset.y;
bias_elements = min(uint(M_group), M - A_offset.y);
} else {
bias += B_offset.x;
bias_elements = min(uint(N_group), N - B_offset.x);
}
simdgroup_event event;
event.async_copy(threadgroup_block, bias, bias_elements);
simdgroup_event::wait(1, &event);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (is_function_constant_defined(D_trans) && D_trans) {
auto bias = threadgroup_block + offset_in_group.y;
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto D = bias[m];
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto C = C_sram(sram, ushort2(n, m));
*(C->thread_elements()) = D;
}
}
} else {
auto bias = threadgroup_block + offset_in_group.x;
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
auto D = *(threadgroup vec<T, 2>*)(bias + n);
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
auto C = C_sram(sram, ushort2(n, m));
*(C->thread_elements()) = D;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
ushort2 A_tile_src;
ushort2 B_tile_src;
if (sidx == 0) {
A_tile_src.y = min(uint(M_group), M - A_offset.y);
B_tile_src.x = min(uint(N_group), N - B_offset.x);
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, 0);
}
if (K > K_simd && !use_bias) {
#pragma clang loop unroll(full)
for (ushort m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (ushort n = 0; n < N_padded; n += 8) {
*C_sram(sram, ushort2(n, m)) = simdgroup_matrix_storage<T>(0);
}
}
}
for (uint K_floor = 0; K_floor < K; K_floor += K_simd) {
ushort2 A_block_offset(offset_in_simd.x, offset_in_group.y);
ushort2 B_block_offset(offset_in_group.x, offset_in_simd.y);
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_block_leading_dim, A_block_offset, A_trans);
auto B_block_src = simdgroup_matrix_storage<T>::apply_offset(B_block, B_block_leading_dim, B_block_offset, B_trans);
threadgroup_barrier(mem_flags::mem_threadgroup);
#pragma clang loop unroll(full)
for (ushort k = 0; k < K_simd_padded; k += 8) {
bool accumulate = use_bias || !(K <= K_simd && k == 0);
multiply_accumulate(sram, A_block_src, B_block_src, accumulate);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
if (K_floor + K_simd < K) {
#pragma clang loop unroll(full)
for (ushort k = K_simd_padded; k < K_simd; k += 8) {
multiply_accumulate(sram, A_block_src, B_block_src);
A_block_src += A_trans ? 8 * A_block_leading_dim : 8;
B_block_src += B_trans ? 8 : 8 * B_block_leading_dim;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
uint K_next = K_floor + K_simd;
A_offset.x = K_next;
B_offset.y = K_next;
prefetch(A_block, A, A_tile_src, A_offset, B_block, B, B_tile_src, B_offset, K_next);
}
}
}
if (alpha != 1) {
#pragma clang loop unroll(full)
for (int m = 0; m < M_padded; m += 8) {
#pragma clang loop unroll(full)
for (int n = 0; n < N_padded; n += 8) {
C_sram(sram, ushort2(n, m))->thread_elements()[0] *= T(alpha);
}
}
}
uint2 C_offset(B_offset.x, A_offset.y);
ushort2 C_block_offset = offset_in_group.xy;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (beta != 0) {
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_accumulate(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (use_activation_function) {
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_store(sram, C_block, false);
simdgroup_barrier(mem_flags::mem_threadgroup);
uint grid_index_in_batch = (batched ? gid.z : 0);
uint2 matrix_origin = C_offset + uint2(C_block_offset);
matrix_origin &= ~7;
ushort2 tile_dimensions(min(uint(N_group), N - matrix_origin.x),
min(uint(M_group), M - matrix_origin.y));
uint function_index = 0;
if (batched_activation_function) {
function_index = activation_function_offsets[gid.z];
}
table[function_index](C_block, D, grid_index_in_batch, matrix_origin, tile_dimensions, lane_id);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, true);
}
return;
} else if ((M % 8 != 0) || (N % 8 != 0)) {
auto C_block = simdgroup_matrix_storage<T>::apply_offset(threadgroup_block, N_group, C_block_offset);
partial_store(sram, C_block, false);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sidx == 0) {
async_access_accumulator(threadgroup_block, C, C_offset, true);
}
} else {
uint2 matrix_origin = C_offset + uint2(C_block_offset);
auto C_src = simdgroup_matrix_storage<T>::apply_offset(C, N, matrix_origin);
store_accumulator(sram, C_src, false, false);
const uint M_edge_floor = M - M % M_simd;
const uint N_edge_floor = N - N % N_simd;
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, false);
}
if (matrix_origin.x < N_edge_floor) {
store_accumulator(sram, C_src, false, true);
if (matrix_origin.y < M_edge_floor) {
store_accumulator(sram, C_src, true, true);
}
}
}
}
kernel void hgemm(device half *A [[buffer(0)]],
device half *B [[buffer(1)]],
device half *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup half *threadgroup_block [[threadgroup(0)]],
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<half>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
_gemm_impl<half>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}
kernel void sgemm(device float *A [[buffer(0)]],
device float *B [[buffer(1)]],
device float *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup float *threadgroup_block [[threadgroup(0)]],
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<float>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
_gemm_impl<float>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}
#if defined(__HAVE_BFLOAT__)
kernel void bgemm(
device bfloat *A [[buffer(0)]],
device bfloat *B [[buffer(1)]],
device bfloat *C [[buffer(2)]],
device void *D [[buffer(3), function_constant(use_activation)]],
threadgroup bfloat *threadgroup_block [[threadgroup(0)]],
device ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
typename activation_functor<bfloat>::function_table table [[buffer(11), function_constant(use_activation_function)]],
constant uint *activation_function_offsets [[buffer(12), function_constant(batched_activation_function)]],
uint3 gid [[threadgroup_position_in_grid]],
ushort sidx [[simdgroup_index_in_threadgroup]],
ushort lane_id [[thread_index_in_simdgroup]])
{
_gemm_impl<bfloat>(A, B, C, D, threadgroup_block, matrix_offsets, table, activation_function_offsets, gid, sidx, lane_id);
}
#endif

View File

@ -6,42 +6,39 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub mod utils;
mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const RANDOM: &str = include_str!("random.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
// const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const GEMM: &str = include_str!("gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Binary,
Cast,
Conv,
Fill,
Gemm,
Indexing,
Mfa,
Quantized,
Random,
Reduce,
Sort,
Ternary,
Unary,
Binary,
Ternary,
Cast,
Reduce,
Mfa,
Conv,
Random,
Quantized,
Sort,
}
pub mod copy2d {
@ -195,19 +192,17 @@ impl Kernels {
fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
Source::Random => RANDOM,
Source::Reduce => REDUCE,
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Mfa => panic!("Invalid lib"),
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Sort => SORT,
Source::Mfa => GEMM,
}
}
@ -222,22 +217,14 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let lib = match source {
Source::Mfa => {
let source_data = MFA;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
};
let compile_options = CompileOptions::new();
compile_options.set_fast_math_enabled(true);
let source_content = self.get_library_source(source);
let lib = device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
libraries.insert(source, lib.clone());
Ok(lib)
}
@ -2184,201 +2171,5 @@ pub fn call_arg_sort(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum GemmDType {
BF16,
F16,
F32,
}
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_gemm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct GemmParams {
m: i32,
n: i32,
k: i32,
lda: i32,
ldb: i32,
ldd: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_a: isize,
batch_stride_b: isize,
batch_stride_d: isize,
swizzle_log: i32,
gemm_k_iterations_aligned: i32,
batch_ndim: i32,
}
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, false)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, false)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
let constants = Some(ConstantValues::new(vec![
(10, Value::Bool(/* has_batch */ b > 1)),
(100, Value::Bool(/* use_out_source */ false)),
(110, Value::Bool(/* do_axpby */ false)),
(200, Value::Bool(/* align_m */ m % bm == 0)),
(201, Value::Bool(/* align_n */ n % bn == 0)),
(202, Value::Bool(/* align_k */ k % bk == 0)),
(300, Value::Bool(/* do_gather */ false)),
]));
let swizzle_log = 0;
let tile = 1 << swizzle_log;
let tn = n.div_ceil(bn);
let tm = m.div_ceil(bm);
let tn = tn * tile;
let tm = tm.div_ceil(tile);
let batch_stride_a = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3]
} else {
m * k
};
let batch_stride_b = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3]
} else {
n * k
};
let gemm_params = GemmParams {
m: m as i32,
n: n as i32,
k: k as i32,
lda,
ldb,
ldd: n as i32,
tiles_n: tn as i32,
tiles_m: tm as i32,
swizzle_log,
batch_stride_a: batch_stride_a as isize,
batch_stride_b: batch_stride_b as isize,
batch_stride_d: (m * n) as isize,
batch_ndim: 1i32,
gemm_k_iterations_aligned: (k / bk) as i32,
};
let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
// TODO(laurent): generate the name
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
let name = match (dtype, a_trans, b_trans) {
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
};
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(3, Some(output), 0);
encoder.set_bytes(
4,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const c_void,
);
encoder.set_bytes(
6, // batch_shape
std::mem::size_of::<i32>() as u64,
&(b as i32) as *const i32 as *const c_void,
);
encoder.set_bytes(
7,
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
batch_strides.as_ptr() as *const c_void,
);
let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
depth: /* batch_size_out */ b as u64,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
pub fn call_const_fill(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (output, v, length));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[cfg(test)]
mod tests;

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -330,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
#[test]
fn cast_f32() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -361,7 +360,7 @@ fn cast_f32() {
#[test]
fn cast_f16() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -392,7 +391,7 @@ fn cast_f16() {
#[test]
fn cast_bf16() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -423,7 +422,7 @@ fn cast_bf16() {
#[test]
fn cast_u32() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -454,7 +453,7 @@ fn cast_u32() {
#[test]
fn cast_u8() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -485,7 +484,7 @@ fn cast_u8() {
#[test]
fn cast_i64() {
let v_f64 = [1.0f64, 2.0, 3.0];
let v_f64 = vec![1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@ -912,7 +911,7 @@ fn softmax() {
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@ -923,7 +922,7 @@ fn softmax() {
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -1046,15 +1045,14 @@ fn where_cond_u32_f32() {
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[allow(clippy::too_many_arguments)]
fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: &[usize],
lhs_stride: Vec<usize>,
lhs_offset: usize,
rhs: &[T],
rhs_stride: &[usize],
rhs_stride: Vec<usize>,
rhs_offset: usize,
) -> Vec<T> {
let device = device();
@ -1081,10 +1079,10 @@ fn run_gemm<T: Clone>(
&kernels,
name,
(b, m, n, k),
lhs_stride,
&lhs_stride,
lhs_offset,
&lhs,
rhs_stride,
&rhs_stride,
rhs_offset,
&rhs,
&output,
@ -1107,10 +1105,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
lhs_stride,
0,
&rhs,
&rhs_stride,
rhs_stride,
0,
);
assert_eq!(
@ -1127,10 +1125,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
lhs_stride,
0,
&rhs,
&rhs_stride,
rhs_stride,
0,
);
assert_eq!(
@ -1152,10 +1150,10 @@ fn gemm() {
"sgemm",
(1, m, n, k),
&lhs,
&lhs_stride,
lhs_stride,
0,
&rhs,
&rhs_stride,
rhs_stride,
12 * 4,
);
assert_eq!(
@ -1164,27 +1162,25 @@ fn gemm() {
);
// bgemm sanity test
if false {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_gemm(
"bgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
0,
&rhs,
&rhs_stride,
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_gemm(
"bgemm",
(b, m, n, k),
&lhs,
lhs_stride,
0,
&rhs,
rhs_stride,
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
@ -1196,10 +1192,10 @@ fn gemm() {
"hgemm",
(b, m, n, k),
&lhs,
&lhs_stride,
lhs_stride,
0,
&rhs,
&rhs_stride,
rhs_stride,
0,
);
assert_eq!(
@ -1208,204 +1204,6 @@ fn gemm() {
);
}
#[allow(clippy::too_many_arguments)]
fn run_mlx_gemm<T: Clone>(
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_mlx_gemm(
&device,
command_buffer,
&kernels,
dtype,
(b, m, n, k),
lhs_stride,
lhs_offset,
&lhs,
rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
use rand::SeedableRng;
use rand_distr::Distribution;
let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
let v1: Vec<f32> = run_mlx_gemm(
dtype,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
let v2: Vec<f32> = run_gemm(
"sgemm",
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[k * n, n, 1],
0,
);
for (a, b) in v1.iter().zip(v2.iter()) {
let diff = (a - b).abs();
assert_eq!((diff * 1e4).round(), 0.)
}
}
#[test]
fn mlx_vs_mfa() {
mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
}
#[test]
fn mlx_gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_mlx_gemm(
GemmDType::F32,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_mlx_gemm(
GemmDType::F32,
(1, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
12 * 4,
);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
// bgemm sanity test
{
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::BF16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_bf16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
{
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);
let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
let results = run_mlx_gemm(
GemmDType::F16,
(b, m, n, k),
&lhs,
&[m * k, k, 1],
0,
&rhs,
&[n * k, n, 1],
0,
);
assert_eq!(
approx_f16(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
}
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
@ -1480,7 +1278,7 @@ fn random() {
variance.sqrt()
}
let shape = [1024, 10];
let shape = vec![1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
@ -1836,7 +1634,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@ -1856,7 +1654,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
let expected = [5.0, 7.0, 13.0, 15.0]
let expected = vec![5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@ -1879,7 +1677,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -1899,7 +1697,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
let expected = [5.0, 7.0, 13.0, 15.0]
let expected = vec![5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -2018,7 +1816,7 @@ fn avg_pool2d_f16() {
&strides,
"avg_pool2d_f16",
);
let expected = [
let expected = vec![
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@ -2043,7 +1841,7 @@ fn avg_pool2d_bf16() {
&strides,
"avg_pool2d_bf16",
);
let expected = [
let expected = vec![
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@ -2181,14 +1979,14 @@ fn conv_transpose1d_f32() {
#[test]
fn conv_transpose1d_f16() {
let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
@ -2209,7 +2007,7 @@ fn conv_transpose1d_f16() {
"conv_transpose1d_f16",
);
let expected = [1., 4., 10., 20., 25., 24., 16.]
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@ -2218,14 +2016,14 @@ fn conv_transpose1d_f16() {
#[test]
fn conv_transpose1d_bf16() {
let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
@ -2246,7 +2044,7 @@ fn conv_transpose1d_bf16() {
"conv_transpose1d_bf16",
);
let expected = [1., 4., 10., 20., 25., 24., 16.]
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
@ -2308,33 +2106,3 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);
test::<i64, _>("fill_i64", |v| v as i64);
test::<f16, _>("fill_f16", f16::from_f32);
test::<bf16, _>("fill_bf16", bf16::from_f32);
test::<f32, _>("fill_f32", |v| v);
}

View File

@ -56,7 +56,7 @@ template <typename T> METAL_FUNC T gelu(T x) {
T x_cube = x_sq * x;
T alpha = x + static_cast<T>(0.044715) * x_cube;
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(precise::tanh(beta)));
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
template <typename T> METAL_FUNC T relu(T in){
if (in < 0) {
@ -154,6 +154,7 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY_OP(relu)
UNARY_OP(sign)
@ -163,11 +164,6 @@ UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
// tanh may create NaN on large values, e.g. 45 rather than outputing 1.
// This has been an issue for the encodec example.
UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided);
UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
@ -189,6 +185,7 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
BFLOAT_UNARY_OP(sign)
@ -196,7 +193,5 @@ BFLOAT_UNARY_OP(sigmoid)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
COPY2D(copy2d_bf16, bfloat)
#endif

View File

@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@ -61,14 +61,18 @@ pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub trait EncoderParam {
pub(crate) trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
@ -161,25 +165,20 @@ pub trait EncoderProvider {
type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_>;
fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
}
pub struct WrappedEncoder<'a> {
inner: &'a ComputeCommandEncoderRef,
end_encoding_on_drop: bool,
}
pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) {
if self.end_encoding_on_drop {
self.inner.end_encoding()
}
self.0.end_encoding()
}
}
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
self.inner
&self.0
}
}
@ -187,11 +186,8 @@ impl EncoderProvider for &metal::CommandBuffer {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
WrappedEncoder(self.new_compute_command_encoder())
}
}
@ -199,22 +195,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self.new_compute_command_encoder(),
end_encoding_on_drop: true,
}
}
}
impl EncoderProvider for &ComputeCommandEncoderRef {
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
WrappedEncoder {
inner: self,
end_encoding_on_drop: false,
}
fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
WrappedEncoder(self.new_compute_command_encoder())
}
}

View File

@ -1,4 +1,4 @@
use candle::{Device, Result, Tensor};
use candle::{Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@ -145,225 +145,3 @@ impl KvCache {
self.v.reset();
}
}
#[derive(Debug, Clone)]
pub struct RotatingCache {
all_data: Option<Tensor>,
dim: usize,
// `offset` is the current write index in the buffer
offset: usize,
// The total size of the sequence seen so far.
current_seq_len: usize,
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
max_seq_len: usize,
}
impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
}
}
pub fn offset(&self) -> usize {
self.offset
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn current_seq_len(&self) -> usize {
self.current_seq_len
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => {
if self.current_seq_len >= self.max_seq_len {
Some(d.clone())
} else {
Some(d.narrow(self.dim, 0, self.current_seq_len)?)
}
}
};
Ok(data)
}
pub fn reset(&mut self) {
self.offset = 0;
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
self.current_seq_len += seq_len;
if seq_len >= self.max_seq_len {
let to_copy = src
.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
.contiguous()?;
ad.slice_set(&to_copy, self.dim, 0)?;
self.offset = 0;
// Here we return `src` rather than `ad` so that all the past can be used.
Ok(src.clone())
} else {
let rem_len = self.max_seq_len - self.offset;
if seq_len <= rem_len {
ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
self.offset = (self.offset + seq_len) % self.max_seq_len;
} else {
// We have to make two copies here as we go over the boundary of the cache.
if rem_len > 0 {
let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
ad.slice_set(&src1, self.dim, self.offset)?;
}
let src2 = src
.narrow(self.dim, rem_len, seq_len - rem_len)?
.contiguous()?;
ad.slice_set(&src2, self.dim, 0)?;
self.offset = seq_len - rem_len;
}
if self.current_seq_len >= self.max_seq_len {
Ok(ad.clone())
} else {
Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
}
}
}
fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|i| {
(0..size2).map(move |j| {
u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
let context = self.max_seq_len;
let upd_offset = (self.offset + size1) % self.max_seq_len;
let mask: Vec<_> = (0..size1)
.flat_map(|pos_src| {
// The absolute position of the elements that will get added to the cache.
let pos_src = self.current_seq_len + pos_src;
(0..size2).map(move |pos_cache_rel| {
// The absolute position of the cache elements after the addition.
let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
let pos_cache = if pos_cache_rel < upd_offset {
pos_cache
} else {
pos_cache - self.max_seq_len
};
u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
})
})
.collect();
Tensor::from_slice(&mask, (size1, size2), device)
}
/// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
let mask = if seq_len == 1 {
None
} else {
let mask = if seq_len < self.max_seq_len {
let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
self.get_mask_rel(seq_len, cache_out_len, device)?
} else {
self.get_mask_abs(seq_len, seq_len, device)?
};
Some(mask)
};
Ok(mask)
}
}
#[derive(Debug, Clone)]
pub struct RotatingKvCache {
k: RotatingCache,
v: RotatingCache,
}
impl RotatingKvCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, max_seq_len);
let v = RotatingCache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k_cache(&self) -> &RotatingCache {
&self.k
}
pub fn v_cache(&self) -> &RotatingCache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
let out_k = self.k.append(k)?;
let out_v = self.v.append(v)?;
Ok((out_k, out_v))
}
pub fn offset(&self) -> usize {
self.k.offset()
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
self.k.attn_mask(seq_len, device)
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}

View File

@ -543,23 +543,15 @@ impl candle::CustomOp2 for RmsNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (block_size, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -784,24 +776,15 @@ impl candle::CustomOp3 for LayerNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (block_size, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)

View File

@ -55,10 +55,6 @@ pub struct LSTMState {
}
impl LSTMState {
pub fn new(h: Tensor, c: Tensor) -> Self {
LSTMState { h, c }
}
/// The hidden state vector, which is also the output of the LSTM.
pub fn h(&self) -> &Tensor {
&self.h
@ -70,12 +66,6 @@ impl LSTMState {
}
}
#[derive(Debug, Clone, Copy)]
pub enum Direction {
Forward,
Backward,
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
@ -84,7 +74,6 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
pub direction: Direction,
}
impl Default for LSTMConfig {
@ -95,7 +84,6 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -108,7 +96,6 @@ impl LSTMConfig {
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -116,7 +103,7 @@ impl LSTMConfig {
/// A Long Short-Term Memory (LSTM) layer.
///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms)]
#[allow(clippy::upper_case_acronyms, unused)]
#[derive(Clone, Debug)]
pub struct LSTM {
w_ih: Tensor,
@ -129,62 +116,6 @@ pub struct LSTM {
dtype: DType,
}
impl LSTM {
/// Creates a LSTM layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let layer_idx = config.layer_idx;
let direction_str = match config.direction {
Direction::Forward => "",
Direction::Backward => "_reverse",
};
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_ih_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_hh_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &LSTMConfig {
&self.config
}
}
/// Creates a LSTM layer.
pub fn lstm(
in_dim: usize,
@ -192,7 +123,39 @@ pub fn lstm(
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<LSTM> {
LSTM::new(in_dim, hidden_dim, config, vb)
let layer_idx = config.layer_idx;
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
}
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
}
None => None,
};
Ok(LSTM {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
impl RNN for LSTM {
@ -286,7 +249,7 @@ impl GRUConfig {
/// A Gated Recurrent Unit (GRU) layer.
///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms)]
#[allow(clippy::upper_case_acronyms, unused)]
#[derive(Clone, Debug)]
pub struct GRU {
w_ih: Tensor,
@ -299,56 +262,41 @@ pub struct GRU {
dtype: DType,
}
impl GRU {
/// Creates a GRU layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &GRUConfig {
&self.config
}
}
/// Creates a GRU layer.
pub fn gru(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<GRU> {
GRU::new(in_dim, hidden_dim, config, vb)
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(GRU {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
impl RNN for GRU {

Some files were not shown because too many files have changed in this diff Show More