Compare commits


6 Commits

Author SHA1 Message Date
5221146cfa Cuda quantization padding fix. 2024-09-25 23:35:16 +02:00
fd3b53f48b Fix for the quantized model. 2024-09-25 12:34:46 +02:00
c6019e9635 Use the newly minted gguf file. 2024-09-25 12:08:20 +02:00
8cc560bb8c Hook the quantized model. 2024-09-25 11:24:50 +02:00
0bd61bae29 More generic sampling. 2024-09-25 11:15:37 +02:00
fa1e0e438e Quantized version of flux. 2024-09-25 11:07:49 +02:00
234 changed files with 1039 additions and 15206 deletions

View File

@ -9,8 +9,7 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
runs-on:
group: aws-g4dn-2xlarge
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0

View File

@ -16,9 +16,6 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@ -37,13 +34,7 @@ jobs:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [stable]
steps:
- name: Delete huge unnecessary tools folder
if: runner.os == 'Linux'
run: rm -rf /opt/hostedtoolcache
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
profile: minimal

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.1"
version = "0.7.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,20 +33,20 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.1" }
candle-datasets = { path = "./candle-datasets", version = "0.8.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.1" }
candle-kernels = { path = "./candle-kernels", version = "0.8.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.1" }
candle-nn = { path = "./candle-nn", version = "0.8.1" }
candle-onnx = { path = "./candle-onnx", version = "0.8.1" }
candle-transformers = { path = "./candle-transformers", version = "0.8.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
candle-nn = { path = "./candle-nn", version = "0.7.1" }
candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
@ -70,9 +70,6 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
ug-metal = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -2,8 +2,7 @@
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)
![License](https://img.shields.io/crates/l/candle-core.svg)
Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
@ -188,7 +187,6 @@ And then head over to
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible.
If you have an addition to this list, please submit a pull request.

View File

@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas
```rust
# extern crate candle_core;
# extern crate candle_hf_hub;
use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate candle_hf_hub;
# use candle_hf_hub::api::sync::Api;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
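Editor's note, not part of the diff: the two sides of the book excerpt above differ only in which hub crate is imported (`candle_hf_hub` vs `hf_hub`); the download call itself is unchanged. As a minimal sketch of what typically follows this step, the snippet below memory-maps the downloaded weights into a `VarBuilder`. It assumes the `candle-nn` and `anyhow` crates, which are not shown in this hunk but appear elsewhere in this compare.

```rust
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Same download step as in the book excerpt above.
    let api = Api::new()?;
    let repo = api.model("bert-base-uncased".to_string());
    let weights = repo.get("model.safetensors")?;
    // Memory-map the weights file; this is `unsafe` because the file must not be
    // modified while it is mapped.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    let _ = vb; // hand this to a model constructor, e.g. a BERT loader
    Ok(())
}
```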

View File

@ -28,9 +28,6 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
@ -42,11 +39,11 @@ criterion = { workspace = true }
[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"

View File

@ -1,5 +1,3 @@
//! Traits to Define Backend Behavior
//!
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};

View File

@ -1,4 +1,4 @@
//! Methods for backpropagation of gradients.
/// Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

View File

@ -1,5 +1,3 @@
//! 1D and 2D Convolutions
//!
use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]

View File

@ -1,5 +1,3 @@
//! Traits and methods for CPU-backed Tensors
pub mod erf;
pub mod kernels;

View File

@ -1,4 +1,3 @@
//! Implementation of Backend Fns for CPU
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
@ -66,7 +65,7 @@ impl Map2U8 for Cmp {
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
impl<I: IntDType> Map2 for WCond<'_, I> {
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@ -216,7 +215,7 @@ struct ReduceSum<'a> {
reduce_dims_and_stride: Vec<(usize, usize)>,
}
impl ReduceSum<'_> {
impl<'a> ReduceSum<'a> {
#[inline(always)]
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
where
@ -281,7 +280,7 @@ impl ReduceSum<'_> {
}
}
impl Map1 for ReduceSum<'_> {
impl<'a> Map1 for ReduceSum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
self.fold_impl(src, src_l, T::zero())
@ -454,7 +453,7 @@ struct Gather<'a, I: IntDType> {
dim: usize,
}
impl<I: IntDType> Map1 for Gather<'_, I> {
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
@ -507,7 +506,7 @@ struct IndexSelect<'a, T: IntDType> {
dim: usize,
}
impl<I: IntDType> Map1 for IndexSelect<'_, I> {
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
@ -560,7 +559,7 @@ struct ScatterAdd<'a, I: IntDType> {
dim: usize,
}
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
const OP: &'static str = "scatter-add";
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let dst_len = l1.shape().elem_count();
@ -616,7 +615,7 @@ struct IndexAdd<'a, I: IntDType> {
dim: usize,
}
impl<I: IntDType> Map2 for IndexAdd<'_, I> {
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
const OP: &'static str = "index-add";
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
// v1, l1 -> self
@ -736,7 +735,7 @@ fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl Map2 for Conv1D<'_> {
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -960,7 +959,7 @@ impl Map1 for Col2Im1D {
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl Map2 for ConvTranspose1D<'_> {
impl<'a> Map2 for ConvTranspose1D<'a> {
const OP: &'static str = "conv_transpose1d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -1029,7 +1028,7 @@ impl Map2 for ConvTranspose1D<'_> {
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl Map2 for Conv2D<'_> {
impl<'a> Map2 for Conv2D<'a> {
const OP: &'static str = "conv2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
@ -1117,7 +1116,7 @@ impl Map2 for Conv2D<'_> {
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl Map2 for ConvTranspose2D<'_> {
impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;

View File

@ -26,7 +26,6 @@ impl From<cudarc::driver::DriverError> for crate::Error {
pub(crate) fn launch_conv2d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
@ -49,7 +48,7 @@ pub(crate) fn launch_conv2d<
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
let conv = cudnn.create_conv2d::<T>(
/* pad */ [params.padding as i32, params.padding as i32],
/* stride */ [params.stride as i32, params.stride as i32],
/* dilation */ [params.dilation as i32, params.dilation as i32],
@ -63,18 +62,18 @@ pub(crate) fn launch_conv2d<
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(
cudnn.create_4d_tensor_ex(
x_shape,
[s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32],
)?
};
let w = cudnn.create_4d_filter::<T>(
let w = cudnn.create_4d_filter(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
@ -84,7 +83,7 @@ pub(crate) fn launch_conv2d<
],
)?;
let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32);
let y = cudnn.create_4d_tensor::<T>(
let y = cudnn.create_4d_tensor(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, h_out, w_out],
)?;

View File

@ -51,27 +51,6 @@ impl CudaDevice {
self.device.clone()
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
}
pub fn id(&self) -> DeviceId {
self.id
}
@ -165,20 +144,6 @@ impl CudaDevice {
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
})
}
}
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;

View File

@ -1,5 +1,3 @@
//! Implementation of Backend traits for CUDA device
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
@ -255,7 +253,7 @@ impl Map1 for Powf {
}
struct FastReduce<'a>(&'a [usize], ReduceOp);
impl Map1Any for FastReduce<'_> {
impl<'a> Map1Any for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
@ -350,7 +348,7 @@ impl<U: UnaryOpT> Map1 for U {
}
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map1 for IndexSelect<'_> {
impl<'a> Map1 for IndexSelect<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -410,7 +408,7 @@ impl Map1 for IndexSelect<'_> {
}
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map1 for Gather<'_> {
impl<'a> Map1 for Gather<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
@ -461,7 +459,7 @@ impl Map1 for Gather<'_> {
}
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for IndexAdd<'_> {
impl<'a> Map2InPlace for IndexAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -509,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
}
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for ScatterAdd<'_> {
impl<'a> Map2InPlace for ScatterAdd<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
@ -554,7 +552,7 @@ impl Map2InPlace for ScatterAdd<'_> {
}
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl Map2 for Conv1D<'_> {
impl<'a> Map2 for Conv1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -595,7 +593,7 @@ impl Map2 for Conv1D<'_> {
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl Map2 for Conv2D<'_> {
impl<'a> Map2 for Conv2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -660,7 +658,7 @@ impl Map1 for Col2Im1D {
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl Map2 for ConvTranspose1D<'_> {
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -709,7 +707,7 @@ impl Map2 for ConvTranspose1D<'_> {
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl Map2 for ConvTranspose2D<'_> {
impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
@ -850,7 +848,7 @@ impl Map1 for UpsampleNearest2D {
}
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl Map2 for WhereCond<'_> {
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
t: &CudaSlice<T>,
@ -1524,7 +1522,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
@ -1532,10 +1530,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
crate::cudnn::launch_conv2d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
@ -1543,7 +1538,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
@ -1551,7 +1546,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
@ -1559,7 +1554,7 @@ impl BackendStorage for CudaStorage {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}

View File

@ -375,110 +375,3 @@ impl Tensor {
)
}
}
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
}
impl UgIOp1 {
#[allow(unused)]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(feature = "metal")]
{
let device = device.as_metal_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
}
}
impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
#[cfg(feature = "metal")]
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
use crate::backend::BackendStorage;
use candle_metal_kernels::utils::EncoderProvider;
let elem_count = layout.shape().elem_count();
if sto.dtype() != crate::DType::F32 {
// TODO: support more dtypes.
crate::bail!("input is not a f32 tensor")
}
let device = sto.device();
println!("here");
let command_buffer = device.command_buffer()?;
let command_buffer = &command_buffer;
let encoder = command_buffer.encoder();
let encoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&self.func);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let grid_dims = metal::MTLSize {
width: g as u64,
height: 1,
depth: 1,
};
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);
Ok(())
}
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
let elem_count = layout.shape().elem_count();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}

View File

@ -11,7 +11,6 @@ pub enum DeviceLocation {
Metal { gpu_id: usize },
}
/// Cpu, Cuda, or Metal
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
@ -131,26 +130,6 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
match self {
Self::Cuda(d) => Ok(d),
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
}
}
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
match self {
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
Self::Metal(d) => Ok(d),
}
}
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}

View File

@ -1,7 +1,6 @@
//! Pretty printing of tensors
//!
//! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py).
//!
/// Pretty printing of tensors
/// This implementation should be in line with the PyTorch version.
/// https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py
use crate::{DType, Result, Tensor, WithDType};
use half::{bf16, f16};

View File

@ -1,5 +1,3 @@
//! Implementation of the Cuda backend when Cuda support has not been compiled in.
//!
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
@ -16,12 +14,6 @@ macro_rules! fail {
};
}
impl CudaDevice {
pub fn new_with_stream(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
}
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;

View File

@ -1,4 +1,3 @@
//! Candle-specific Error and Result
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
@ -166,9 +165,6 @@ pub enum Error {
#[error("Metal error {0}")]
Metal(#[from] MetalError),
#[error(transparent)]
Ug(#[from] ug::Error),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
@ -183,10 +179,6 @@ pub enum Error {
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// Utf8 parse error.
#[error(transparent)]
FromUtf8(#[from] std::string::FromUtf8Error),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),

View File

@ -1,4 +1,3 @@
//! Tensor Layouts including contiguous or sparse strides
use crate::{Error, Result, Shape};
#[derive(Debug, PartialEq, Eq, Clone)]
@ -36,12 +35,6 @@ impl Layout {
self.shape.dims()
}
/// The dimension size for a specified dimension index.
pub fn dim<D: crate::shape::Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(&self.shape, "dim")?;
Ok(self.dims()[dim])
}
pub fn shape(&self) -> &Shape {
&self.shape
}

View File

@ -7,8 +7,8 @@
//!
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//! let c = a.matmul(&b)?;
//!
//! let c = a.matmul(&b)?;
//! # Ok(())}
//! ```
//!
@ -32,20 +32,6 @@
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds the core common data structures but you may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models.
//!
#[cfg(feature = "accelerate")]
mod accelerate;
@ -91,7 +77,7 @@ mod variable;
pub use cuda_backend::cudnn;
pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
@ -140,7 +126,7 @@ impl ToUsize2 for (usize, usize) {
}
}
/// Defining a module with forward method using a single argument.
// A simple trait defining a module with forward method using a single argument.
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@ -160,8 +146,8 @@ impl<M: Module> Module for Option<&M> {
}
}
/// A single forward method using a single single tensor argument and a flag to
/// separate the training and evaluation behaviors.
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

View File

@ -144,28 +144,6 @@ impl MetalDevice {
self.use_mlx_mm = use_mlx_mm
}
pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
let mut buf = vec![];
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
let metal_code = String::from_utf8(buf)?;
let lib = self
.device
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
.map_err(MetalError::from)?;
let func = lib
.get_function(func_name, None)
.map_err(MetalError::from)?;
let pl = self
.device
.new_compute_pipeline_state_with_function(&func)
.map_err(MetalError::from)?;
Ok(pl)
}
pub fn id(&self) -> DeviceId {
self.id
}

View File

@ -1,5 +1,3 @@
//! Implementation of Backend traits for Metal
//!
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
@ -1239,12 +1237,11 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "gather")?;
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(DType::U32, DType::U32) => "gather_u32_u32",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
@ -1284,7 +1281,6 @@ impl BackendStorage for MetalStorage {
(DType::U8, DType::F32) => "sa_u8_f32",
(DType::U8, DType::F16) => "sa_u8_f16",
(DType::U8, DType::BF16) => "sa_u8_bf16",
(DType::U32, DType::U32) => "sa_u32_u32",
(DType::U32, DType::F32) => "sa_u32_f32",
(DType::U32, DType::F16) => "sa_u32_f16",
(DType::U32, DType::BF16) => "sa_u32_bf16",
@ -1328,23 +1324,14 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::U8) => "is_u8_u8",
(DType::U8, DType::U32) => "is_u8_u32",
(DType::U8, DType::I64) => "is_u8_i64",
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::U8) => "is_u32_u8",
(DType::U32, DType::U32) => "is_u32_u32",
(DType::U32, DType::I64) => "is_u32_i64",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::U8) => "is_i64_u8",
(DType::I64, DType::U32) => "is_i64_u32",
(DType::I64, DType::I64) => "is_i64_i64",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
@ -1878,9 +1865,9 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let kernels = Arc::new(Kernels::new());
let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true,
Ok(_) => false,
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
@ -1930,38 +1917,10 @@ impl BackendDevice for MetalDevice {
))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
1.,
)
.map_err(MetalError::from)?;
Ok(MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {

View File

@ -1,5 +1,3 @@
//! Tensor Operation Enums and Traits
//!
#![allow(clippy::redundant_closure_call)]
use crate::Tensor;
use half::{bf16, f16};

View File

@ -1,4 +1,4 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};

View File

@ -6,15 +6,9 @@ use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
inner: CudaSlice<u8>,
len: usize,
}
#[derive(Clone, Debug)]
pub struct QCudaStorage {
data: PaddedCudaSlice,
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
@ -36,11 +30,14 @@ pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
pub const MATRIX_ROW_PADDING: usize = 512;
fn ceil_div(p: usize, q: usize) -> usize {
p.div_ceil(q)
(p + q - 1) / q
}
fn pad(p: usize, q: usize) -> usize {
ceil_div(p, q) * q
// Overallocate by q rather than just padding by q as this should pad the last row
// and we don't have enough information here to know how many elements to add :(
// ceil_div(p, q) * q
p + q
}
fn quantize_q8_1(
@ -67,7 +64,7 @@ fn quantize_q8_1(
}
fn dequantize_f32(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -110,21 +107,21 @@ fn dequantize_f32(
};
if is_k {
let params = (&data.inner, &dst);
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_f16(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
@ -167,21 +164,21 @@ fn dequantize_f16(
};
if is_k {
let params = (&data.inner, &dst);
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mul_mat_vec(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -190,7 +187,7 @@ fn dequantize_mul_mat_vec(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -219,13 +216,13 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_vec_via_q8_1(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
@ -235,7 +232,7 @@ fn mul_mat_vec_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
}
@ -282,7 +279,7 @@ fn mul_mat_vec_via_q8_1(
};
let params = (
&data.inner,
data,
&y_q8_1,
&dst,
/* ncols_x */ ncols as i32,
@ -296,7 +293,7 @@ fn mul_mat_vec_via_q8_1(
#[allow(clippy::too_many_arguments)]
fn mul_mat_via_q8_1(
data: &PaddedCudaSlice,
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
@ -307,7 +304,7 @@ fn mul_mat_via_q8_1(
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
@ -321,7 +318,7 @@ fn mul_mat_via_q8_1(
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
@ -351,7 +348,7 @@ fn mul_mat_via_q8_1(
};
let params = (
/* vx */ &data.inner,
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
@ -367,14 +364,9 @@ fn mul_mat_via_q8_1(
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: size_in_bytes,
},
data,
device: device.clone(),
dtype,
})
@ -414,10 +406,7 @@ impl QCudaStorage {
}
// Run the dequantization on cpu.
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.w()?;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -453,26 +442,18 @@ impl QCudaStorage {
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src_len = pad(src.len(), MATRIX_ROW_PADDING);
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
};
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len
self.data.len()
}
pub fn fwd(
@ -595,19 +576,11 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.w()?;
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
len: data.len(),
},
data,
device: device.clone(),
dtype,
dtype: T::DTYPE,
}))
}
@ -707,28 +680,4 @@ mod test {
assert_eq!(vs[15], 13138824.0);
Ok(())
}
// The following test used to fail under compute-sanitizer until #2526.
#[test]
fn cuda_mm_q8_1_pad() -> Result<()> {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ x_rows,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ y_cols,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
Ok(())
}
}
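Editor's note, not part of the diff: the two sides of this file differ mainly in how quantized CUDA buffers are sized. One side wraps the storage in a `PaddedCudaSlice` and pads allocations with `MATRIX_ROW_PADDING`, the other keeps a plain `CudaSlice<u8>` and overallocates by a full `q` in `pad`. The sketch below just illustrates the round-up arithmetic that `ceil_div`/`pad` perform in the padded variant; it is self-contained and uses no candle APIs.

```rust
// Round-up helpers as used for MATRIX_ROW_PADDING-style padding.
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}

fn main() {
    const MATRIX_ROW_PADDING: usize = 512;
    assert_eq!(ceil_div(1000, MATRIX_ROW_PADDING), 2);
    assert_eq!(pad(1000, MATRIX_ROW_PADDING), 1024); // rounded up to the next multiple of 512
    assert_eq!(pad(1024, MATRIX_ROW_PADDING), 1024); // already aligned, unchanged
}
```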

View File

@ -134,7 +134,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
super::QTensor::new(data, dims)
}
/// Creates a Tensor from a raw GGML tensor.
/// Creates a [Tensor] from a raw GGML tensor.
pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],

View File

@ -1,5 +1,6 @@
//! Support for the [GGUF file format](https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md).
//! Support for the GGUF file format.
//!
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
@ -457,7 +458,7 @@ impl Content {
Some(Value::I32(v)) if *v >= 0 => *v as u64,
_ => DEFAULT_ALIGNMENT,
};
let tensor_data_offset = position.div_ceil(alignment) * alignment;
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
Ok(Self {
magic,
metadata,
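Editor's note, not part of the diff: the two alignment expressions in the hunk above are equivalent ways of rounding the tensor data offset up to the next multiple of `alignment`; `div_ceil` is simply the standard-library spelling of the manual `(x + a - 1) / a` idiom. A tiny sketch:

```rust
fn main() {
    let (position, alignment): (u64, u64) = (37, 32);
    let a = position.div_ceil(alignment) * alignment;          // std integer ceiling division
    let b = (position + alignment - 1) / alignment * alignment; // manual equivalent
    assert_eq!(a, 64);
    assert_eq!(a, b);
}
```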

View File

@ -1850,8 +1850,8 @@ pub fn matmul<T: GgmlType>(
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
}
let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
// TODO: Do not make this copy if the DotType is f32.
// TODO: Pre-allocate this.
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];

View File

@ -1,4 +1,3 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

View File

@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_blocks = ys.len();
// Validate that the input is the right size
if expected_blocks != actual_blocks {
if actual_blocks < expected_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}

View File

@ -1,14 +1,3 @@
//! Module to load `safetensor` files into CPU/GPU memory.
//!
//! There are multiple ways to load tensors from safetensor files:
//! - `load` function for loading directly into memory and returning a HashMap of tensors
//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
//! - `SliceSafetensors` for working with in-memory buffers
//! - `BufferedSafetensors` for owning a buffer of data
//!
//! Tensors can also be serialized to safetensor format using the `save` function or
//! `Tensor::save_safetensors` method.
//!
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
@ -182,7 +171,7 @@ pub trait Load {
fn load(&self, device: &Device) -> Result<Tensor>;
}
impl Load for st::TensorView<'_> {
impl<'a> Load for st::TensorView<'a> {
fn load(&self, device: &Device) -> Result<Tensor> {
convert(self, device)
}

View File

@ -1,5 +1,3 @@
//! TensorScalar Enum and Trait
//!
use crate::{Result, Tensor, WithDType};
pub enum TensorScalar {

View File

@ -142,12 +142,6 @@ impl Shape {
&self.0
}
/// The dimension size for a specified dimension index.
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
let dim = dim.to_index(self, "dim")?;
Ok(self.dims()[dim])
}
/// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()

View File

@ -1,5 +1,3 @@
//! StreamTensor useful for streaming ops.
//!
use crate::{Result, Shape, Tensor};
pub trait Dim: crate::shape::Dim + Copy {}

View File

@ -32,7 +32,7 @@ impl<'a> StridedIndex<'a> {
}
}
impl Iterator for StridedIndex<'_> {
impl<'a> Iterator for StridedIndex<'a> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {

View File

@ -242,7 +242,7 @@ impl Tensor {
Self::zeros_impl(shape, dtype, device, false)
}
/// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other
/// tensor.
///
/// ```rust
@ -1520,15 +1520,14 @@ impl Tensor {
/// # Arguments
///
/// * `self` - The input tensor.
/// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
/// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
/// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
/// but can have a different number of elements on the target dimension.
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
let self_dims = self.dims();
let indexes_dims = indexes.dims();
let mismatch = if indexes_dims.len() != self_dims.len() {
@ -1536,7 +1535,7 @@ impl Tensor {
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
if i != dim && d1 < d2 {
if i != dim && d1 != d2 {
mismatch = true;
break;
}
@ -1760,42 +1759,6 @@ impl Tensor {
&self.op
}
/// Computes the max of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.max_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 5.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.max(0)
}
}
/// Computes the min of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.min_all()?;
/// assert_eq!(tensor.to_scalar::<f32>()?, 0.);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn min_all(&self) -> Result<Tensor> {
if self.rank() == 0 {
Ok(self.clone())
} else {
self.flatten_all()?.min(0)
}
}
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
/// scalar with zero dimensions.
///
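Editor's note, not part of the diff: both versions of the `gather` doc comment above describe the same basic shape rule, that `indexes` has the same rank as `self` and the output takes the shape of `indexes`, while the size along the gather dimension may differ. A short usage sketch, assuming only the `candle_core` API already shown in this compare:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // 2x3 input tensor.
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Same rank as `t`, same size on dim 0, a different size (1) on the gather dim 1.
    let ids = Tensor::new(&[[2u32], [0u32]], &Device::Cpu)?;
    // out[i][j] = t[i][ids[i][j]], so the result has the shape of `ids`.
    let picked = t.gather(&ids, 1)?;
    assert_eq!(picked.to_vec2::<f32>()?, &[[3.0], [4.0]]);
    Ok(())
}
```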

View File

@ -1,4 +1,3 @@
//! Useful functions for checking features.
use std::str::FromStr;
pub fn get_num_threads() -> usize {

View File

@ -143,39 +143,3 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}
#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
let kernel = {
use ug::lang::op;
let layout = ug::Layout::from_shape(&[12]);
let ptr = op::Arg::ptr(ug::DType::F32);
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
let src = op::unary(op::UnaryOp::Exp, src)?;
let st = op::store(ptr.id(), layout, src)?;
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
let device = if candle_core::utils::cuda_is_available() {
Device::new_cuda(0)?
} else if candle_core::utils::metal_is_available() {
Device::new_metal(0)?
} else {
candle_core::bail!("metal/cuda is mandatory for this test")
};
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
to_vec1_round(&t, 2)?,
&[
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
59874.13
]
);
Ok(())
}

View File

@ -29,36 +29,6 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
],
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
]
],
);
assert_eq!(
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
[
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
],
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
]
],
);
Ok(())
}
@ -1047,280 +1017,6 @@ fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
// Random data
// Dim: 0
let t = Tensor::new(
&[
[
[108_f32, -47., 16., -56., -83., -130., 210.],
[253., 95., 151., 228., -210., -123., -127.],
[-9., -217., 2., -78., 163., 245., -204.],
[-246., 79., -238., 88., -226., -184., 171.],
[8., -48., -153., 234., -34., 166., -153.],
[124., 0., -10., -61., -242., -15., -238.],
],
[
[12., -64., -199., 244., -240., 156., -128.],
[173., -57., 4., -198., 233., -110., 238.],
[95., 82., 0., 240., 53., -211., 209.],
[-122., 167., -212., 227., -144., 61., 118.],
[-63., -146., 200., 244., 168., -167., 116.],
[-125., -147., 110., -253., -178., -250., -18.],
],
[
[57., 86., -50., 56., 92., 205., -78.],
[-137., -156., -18., 248., -61., -239., 14.],
[-248., -30., -50., -70., -251., 250., -83.],
[-221., 67., 72., 59., -24., -154., 232.],
[-144., -23., -74., 5., 93., 171., 205.],
[46., -77., -38., -226., 246., 161., -17.],
],
[
[-153., -231., -236., 161., 126., 2., -22.],
[-229., -41., 209., 164., 234., 160., 57.],
[223., 254., -186., -162., -46., -160., -102.],
[65., 30., 213., -253., 59., 224., -154.],
[-82., -203., -177., 17., 31., -256., -246.],
[176., -135., -65., 54., -56., 210., 76.],
],
[
[-10., -245., 168., 124., -14., -33., -178.],
[25., -43., -39., 132., -89., 169., 179.],
[187., -215., 32., -133., 87., -7., -168.],
[-224., -215., -5., -230., -58., -162., 128.],
[158., -137., -122., -100., -202., -83., 136.],
[30., -185., -144., 250., 209., -40., 127.],
],
[
[-196., 108., -245., 122., 146., -228., 62.],
[-1., -66., 160., 137., 13., -172., -21.],
[244., 199., -164., 28., 119., -175., 198.],
[-62., 253., -162., 195., -95., -230., -211.],
[123., -72., -26., -107., -139., 64., 245.],
[11., -126., -182., 108., -12., 184., -127.],
],
[
[-159., 126., 176., 161., 73., -111., -138.],
[-187., 214., -217., -33., -223., -201., -212.],
[-61., -120., -166., -172., -95., 53., 196.],
[-33., 86., 134., -152., 154., -53., 74.],
[186., -28., -154., -174., 141., -109., 217.],
[82., 35., 252., 145., 181., 74., -87.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[
[6_u32, 6, 4, 3, 4, 4, 6],
[3, 3, 2, 4, 4, 4, 6],
[3, 3, 0, 2, 4, 6, 4],
[2, 5, 1, 2, 6, 6, 1],
[2, 1, 6, 5, 3, 2, 3],
[6, 1, 0, 1, 0, 2, 6],
],
[
[4, 6, 4, 3, 3, 3, 2],
[4, 3, 2, 4, 4, 4, 6],
[2, 3, 0, 2, 4, 6, 4],
[6, 5, 1, 2, 6, 6, 1],
[4, 1, 6, 5, 3, 2, 3],
[1, 1, 0, 1, 0, 2, 6],
],
[
[3, 6, 4, 3, 3, 3, 2],
[2, 3, 2, 4, 4, 4, 6],
[4, 3, 0, 2, 4, 6, 4],
[0, 5, 1, 2, 6, 6, 1],
[6, 1, 6, 5, 3, 2, 3],
[4, 1, 0, 1, 0, 2, 6],
],
[
[0, 6, 4, 3, 3, 3, 2],
[5, 3, 2, 4, 4, 4, 6],
[0, 3, 0, 2, 4, 6, 4],
[3, 5, 1, 2, 6, 6, 1],
[0, 1, 6, 5, 3, 2, 3],
[3, 1, 0, 1, 0, 2, 6],
],
],
device,
)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[
[-159_f32, 126., 168., 161., -14., -33., -138.],
[-229., -41., -18., 132., -89., 169., -212.],
[223., 254., 2., -70., 87., 53., -168.],
[-221., 253., -212., 59., 154., -53., 118.],
[-144., -146., -154., -107., 31., 171., -246.],
[82., -147., -10., -253., -242., 161., -87.]
],
[
[-10., 126., 168., 161., 126., 2., -78.],
[25., -41., -18., 132., -89., 169., -212.],
[-248., 254., 2., -70., 87., 53., -168.],
[-33., 253., -212., 59., 154., -53., 118.],
[158., -146., -154., -107., 31., 171., -246.],
[-125., -147., -10., -253., -242., 161., -87.]
],
[
[-153., 126., 168., 161., 126., 2., -78.],
[-137., -41., -18., 132., -89., 169., -212.],
[187., 254., 2., -70., 87., 53., -168.],
[-246., 253., -212., 59., 154., -53., 118.],
[186., -146., -154., -107., 31., 171., -246.],
[30., -147., -10., -253., -242., 161., -87.]
],
[
[108., 126., 168., 161., 126., 2., -78.],
[-1., -41., -18., 132., -89., 169., -212.],
[-9., 254., 2., -70., 87., 53., -168.],
[65., 253., -212., 59., 154., -53., 118.],
[8., -146., -154., -107., 31., 171., -246.],
[176., -147., -10., -253., -242., 161., -87.]
]
]
);
// Dim: 1
let t = Tensor::new(
&[
[
[-117_f32, -175., 69., -163.],
[200., 242., -21., -67.],
[179., 150., -126., -75.],
[-118., 38., -138., -13.],
[-221., 136., -185., 180.],
[58., 182., -204., -149.],
],
[
[3., -148., -58., -154.],
[-43., 45., -108., 4.],
[-69., -249., -71., -21.],
[80., 110., -152., -235.],
[-88., 7., 92., -250.],
[-186., 207., -242., 98.],
],
[
[238., 19., 64., -242.],
[-150., -97., 218., 58.],
[111., -233., 204., -212.],
[-242., -232., 83., 42.],
[153., 62., -251., 219.],
[-117., 36., -119., 10.],
],
[
[215., 159., -169., -27.],
[-83., 101., -88., 169.],
[-205., 93., 225., -64.],
[-162., 240., 214., 23.],
[-112., 6., 21., 245.],
[-38., 113., 93., 215.],
],
[
[91., -188., -148., 101.],
[74., 203., -35., 55.],
[-116., -130., -153., -96.],
[58., 22., -45., -194.],
[-221., -134., 73., 159.],
[-203., -254., 31., 235.],
],
[
[105., -53., 61., 186.],
[-195., 234., 75., -1.],
[51., 139., 160., -108.],
[-173., -167., 161., 19.],
[83., -246., 156., -222.],
[109., 39., -149., 137.],
],
],
device,
)?;
let ids = Tensor::new(
&[
[[4_u32, 4, 4, 2]],
[[0, 4, 4, 3]],
[[1, 5, 3, 4]],
[[0, 3, 3, 2]],
[[1, 1, 5, 2]],
[[1, 4, 5, 4]],
],
device,
)?;
let hs = t.gather(&ids, 1)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-221., 136., -185., -75.]],
[[3., 7., 92., -235.]],
[[-150., 36., 83., 219.]],
[[215., 240., 214., -64.]],
[[74., 203., 31., -96.]],
[[-195., -246., -149., -222.]]
]
);
// Dim: 2
let t = Tensor::new(
&[
[[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]],
[[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]],
],
device,
)?;
let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[202.], [-126.], [-65.], [80.]],
[[37.], [89.], [117.], [220.]]
]
);
let t = Tensor::new(
&[
[[-21_f32, -197.], [194., 122.]],
[[255., -106.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[-130., 238.], [-217., -92.]],
],
device,
)?;
let ids = Tensor::new(
&[
[[0_u32, 1], [1, 0]],
[[1, 0], [0, 1]],
[[0, 1], [0, 1]],
[[1, 0], [1, 0]],
],
device,
)?;
let hs = t.gather(&ids, 2)?;
assert_eq!(
hs.to_vec3::<f32>()?,
&[
[[-21., -197.], [122., 194.]],
[[-106., 255.], [-191., 250.]],
[[33., -117.], [43., 10.]],
[[238., -130.], [-92., -217.]]
]
);
Ok(())
}

View File

@ -87,7 +87,7 @@ impl<'a> DatasetRandomIter<'a> {
}
}
impl Iterator for DatasetRandomIter<'_> {
impl<'a> Iterator for DatasetRandomIter<'a> {
type Item = Result<(Tensor, Tensor)>;
fn next(&mut self) -> Option<Self::Item> {

View File

@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.22.0", features = ["auto-initialize", "abi3-py311"], optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
@ -36,7 +36,6 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
[dev-dependencies]
anyhow = { workspace = true }
@ -66,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal", "rubato"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
@ -118,7 +117,3 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

View File

@ -1,224 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, Device, Tensor};
use candle_nn as nn;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
tracing_subscriber::fmt::init();
let device = candle_examples::device(args.cpu)?;
let var = load_weights(args.model, &device)?;
let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?;
tracing::info!("Transformer loaded. ");
let (pixel_values, vec_imgs) = load_images(args.images, &device)?;
tracing::info!("Images loaded. ");
let tokenizer = load_tokenizer()?;
let (input_ids, type_ids, attention_mask, text_sequences) =
tokenize_sequences(args.sequences, &tokenizer, &device)?;
tracing::info!("Computing ... ");
let (_logits_per_text, logits_per_image) = clip_model.forward(
&pixel_values,
&input_ids,
Some(&type_ids),
Some(&attention_mask),
)?;
let softmax_image = nn::ops::softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
tracing::info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]);
}
}
Ok(())
}
pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder> {
let model_file = match model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? })
}
pub fn load_tokenizer() -> anyhow::Result<Tokenizer> {
let tokenizer_file = {
let api = hf_hub::api::sync::Api::new()?;
let repo = hf_hub::Repo::with_revision(
"OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
);
let api = api.repo(repo);
api.get("tokenizer.json")?
};
Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg)
}
pub fn tokenize_sequences(
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec<String>)> {
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"自行车比赛".to_string(),
"两只猫咪".to_string(),
"拿着蜡烛的机器人".to_string(),
],
};
let mut input_ids = vec![];
let mut type_ids = vec![];
let mut attention_mask = vec![];
let mut max_len = 0;
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?;
input_ids.push(encoding.get_ids().to_vec());
type_ids.push(encoding.get_type_ids().to_vec());
attention_mask.push(encoding.get_attention_mask().to_vec());
if encoding.get_ids().len() > max_len {
max_len = encoding.get_ids().len();
}
}
let pad_id = *tokenizer
.get_vocab(true)
.get("[PAD]")
.ok_or(anyhow::Error::msg("No pad token"))?;
let input_ids: Vec<Vec<u32>> = input_ids
.iter_mut()
.map(|item| {
item.extend(vec![pad_id; max_len - item.len()]);
item.to_vec()
})
.collect();
let type_ids: Vec<Vec<u32>> = type_ids
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let attention_mask: Vec<Vec<u32>> = attention_mask
.iter_mut()
.map(|item| {
item.extend(vec![0; max_len - item.len()]);
item.to_vec()
})
.collect();
let input_ids = Tensor::new(input_ids, device)?;
let type_ids = Tensor::new(type_ids, device)?;
let attention_mask = Tensor::new(attention_mask, device)?;
Ok((input_ids, type_ids, attention_mask, vec_seq))
}
pub fn load_images(
images: Option<Vec<String>>,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let vec_imgs = match images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let mut images = vec![];
for path in vec_imgs.iter() {
let tensor = load_image(path, 224, device)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?.to_device(device)?;
Ok((images, vec_imgs))
}
fn load_image<T: AsRef<std::path::Path>>(
path: T,
image_size: usize,
device: &Device,
) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8().into_raw();
let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?;
let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
let std =
Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
let img = (img.to_dtype(DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)?;
Ok(img)
}

View File

@ -12,6 +12,7 @@ use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::clip;
use tokenizers::Tokenizer;
use tracing::info;
#[derive(Parser)]
struct Args {
@ -39,12 +40,15 @@ fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
// .unsqueeze(0)?;
Ok(img)
}
@ -53,16 +57,24 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
// std::env::set_var("RUST_BACKTRACE", "full");
let args = Args::parse();
tracing_subscriber::fmt::init();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
@ -77,9 +89,13 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = clip::ClipConfig::vit_base_patch32();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -87,29 +103,43 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
// let image = load_image(args.image, config.image_size)?.to_device(&device)?;
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = clip::ClipModel::new(vb, &config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
info!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
info!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
info!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
@ -126,6 +156,7 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
@ -138,6 +169,7 @@ pub fn tokenize_sequences(
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
@ -146,12 +178,16 @@ pub fn tokenize_sequences(
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
@ -159,6 +195,8 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,18 +0,0 @@
# Colpali
[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
```
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```
```
Prompt: what is position encoding?
top 3 page numbers that contain similarity to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```
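The retriever in this example ranks pages with a late-interaction score: each query-token embedding is matched against every page-patch embedding, the maximum similarity per query token is kept, and those maxima are summed over the query tokens. A toy sketch of that scoring, with plain `f32` vectors standing in for the real multi-vector embeddings:

```rust
// Toy late-interaction (max-sim) scoring: for each query token embedding,
// take the max dot product over all page patch embeddings, then sum over
// the query tokens. The values below are made up for illustration.
fn late_interaction_score(query: &[Vec<f32>], page: &[Vec<f32>]) -> f32 {
    query
        .iter()
        .map(|q| {
            page.iter()
                .map(|p| q.iter().zip(p).map(|(a, b)| a * b).sum::<f32>())
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .sum()
}

fn main() {
    // 2 query tokens and 3 page patches, each a 4-dim embedding.
    let query = vec![vec![1.0f32, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];
    let page = vec![
        vec![0.9f32, 0.1, 0.0, 0.0],
        vec![0.0, 0.7, 0.1, 0.0],
        vec![0.0, 0.0, 1.0, 0.0],
    ];
    println!("score = {:.4}", late_interaction_score(&query, &page));
}
```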

View File

@ -1,268 +0,0 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;
struct PageRetriever {
model: Model,
config: paligemma::Config,
pdf: PDF,
device: Device,
tokenizer: Tokenizer,
range: pdf2image::Pages,
batch_size: usize,
top_k: usize,
}
impl PageRetriever {
fn new(
model: Model,
config: paligemma::Config,
pdf: PDF,
tokenizer: Tokenizer,
device: &Device,
range: Option<pdf2image::Pages>,
batch_size: usize,
top_k: usize,
) -> Self {
let page_count = pdf.page_count();
Self {
model,
config,
pdf,
device: device.clone(),
tokenizer,
range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
batch_size,
top_k,
}
}
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
let pages = self
.pdf
.render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
Ok(pages)
}
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.device)
})
.collect::<candle::Result<Vec<_>>>()?;
let input = Tensor::stack(&token_ids, 0)?;
Ok(input)
}
fn images_to_tensor(
&self,
pages: &[DynamicImage],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for page in pages.iter() {
let img = page.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
images.push(img);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
let dtype = if self.device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dummy_prompt: &str = "Describe the image";
let input = self.tokenize_batch(vec![prompt])?;
let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
let pages = self.get_images_from_pdf()?;
let mut all_scores = Vec::new();
for batch in pages.chunks(self.batch_size) {
let page_images = self
.images_to_tensor(batch, self.config.vision_config.image_size)?
.to_device(&self.device)?
.to_dtype(dtype)?;
let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
let text_embeddings = self.model.forward_text(&input)?;
let scores = text_embeddings
.unsqueeze(1)?
.broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
.max(3)?
.sum(2)?;
let batch_scores: Vec<f32> = scores
.to_dtype(DType::F32)?
.to_vec2()?
.into_iter()
.flatten()
.collect();
all_scores.extend(batch_scores);
}
let mut indices: Vec<usize> = (0..all_scores.len()).collect();
indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
let top_k_indices = indices[0..self.top_k].to_vec();
Ok(top_k_indices)
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// number of top pages to show.
#[arg(long, default_value_t = 3)]
top_k: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
pdf: String,
#[arg(long)]
start: Option<u32>,
#[arg(long)]
end: Option<u32>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "vidore/colpali-v1.2-merged".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.repo(Repo::with_revision(
"vidore/colpali".to_string(),
RepoType::Model,
"main".to_string(),
))
.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let start = std::time::Instant::now();
let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(false)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = colpali::Model::new(&config, vb)?;
let pdf = PDF::from_file(args.pdf)?;
// check if start and end given in arg
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
pdf2image::Pages::Range(start..=end)
} else {
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to be rendered twice.
};
let mut retriever =
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
let top_k_indices = retriever.retrieve(&args.prompt)?;
println!("Prompt: {}", args.prompt);
println!(
"top {} page numbers that contain similarity to the prompt",
retriever.top_k
);
println!("-----------------------------------");
for index in top_k_indices {
println!("Page: {:?}", index + 1);
}
println!("-----------------------------------");
Ok(())
}

View File

@ -1,3 +1,4 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

View File

@ -44,14 +44,6 @@ struct Args {
#[arg(long, value_enum, default_value = "schnell")]
model: Model,
/// Use the slower kernels.
#[arg(long)]
use_dmmv: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@ -73,7 +65,6 @@ fn run(args: Args) -> Result<()> {
decode_only,
model,
quantized,
..
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
@ -95,9 +86,6 @@ fn run(args: Args) -> Result<()> {
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
if let Some(seed) = args.seed {
device.set_seed(seed)?;
}
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
@ -256,7 +244,5 @@ fn run(args: Args) -> Result<()> {
fn main() -> Result<()> {
let args = Args::parse();
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.use_dmmv);
run(args)
}

View File

@ -35,26 +35,10 @@ enum Which {
V31,
V3Instruct,
V31Instruct,
V32_1b,
V32_1bInstruct,
V32_3b,
V32_3bInstruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
#[value(name = "SmoLM2-1.7B")]
SmolLM2_1B,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "SmoLM2-360M")]
SmolLM2_360M,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-135M")]
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
}
#[derive(Parser, Debug)]
@ -146,28 +130,15 @@ fn main() -> Result<()> {
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
@ -185,22 +156,10 @@ fn main() -> Result<()> {
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::V32_3b
| Which::V32_3bInstruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
| Which::SmolLM2_135MInstruct
| Which::SmolLM2_1B
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

View File

@ -17,7 +17,7 @@ pub struct Config {
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
self.vocab_size.div_ceil(pad) * pad
(self.vocab_size + pad - 1) / pad * pad
}
fn dt_rank(&self) -> usize {

View File

@ -1,3 +1,4 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::sync::{Arc, Mutex};

View File

@ -60,6 +60,7 @@ fn load_images<T: AsRef<std::path::Path>>(
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = candle_examples::imagenet::load_image_with_std_mean(
path,
@ -69,7 +70,9 @@ fn load_images<T: AsRef<std::path::Path>>(
)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
@ -77,17 +80,24 @@ pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_name = args.which.model_name();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
let model_file = if args.use_pth {
api.get("open_clip_pytorch_model.bin")?
} else {
api.get("open_clip_model.safetensors")?
};
let tokenizer = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let config = &args.which.config();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
@ -95,7 +105,9 @@ pub fn main() -> anyhow::Result<()> {
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
let vb = if args.use_pth {
VarBuilder::from_pth(&model_file, DType::F32, &device)?
} else {
@ -103,15 +115,22 @@ pub fn main() -> anyhow::Result<()> {
};
let model = mobileclip::MobileClipModel::new(vb, config)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
@ -152,6 +171,7 @@ pub fn tokenize_sequences(
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
@ -165,6 +185,8 @@ pub fn tokenize_sequences(
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,43 +0,0 @@
# NV-Embed-v2
Candle implementation (inference only) of [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2), a text embedding model that ranks No. 1 (as of Nov 25 2024) on the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) benchmark with a score of 72.31 across 56 text embedding tasks.
## Running an example: Retrieval
```bash
cargo run --example nvembed_v2 --release
> scores: [[87.4269, 0.4629],
> [ 0.9653, 86.0372]]
> Tensor[[2, 2], f32]
```
In this example, we have two queries and two passages (the corresponding answers). The output tensor represents the similarity scores between each query-passage pair. The scores are computed by taking the dot product of the query and passage embeddings and scaling the result by 100.
```rust
let queries = [
"are judo throws allowed in wrestling?",
"how to become a radiology technician in michigan?",
];
let query_instruction =
"Instruct: Given a question, retrieve passages that answer the question\nQuery: "
.to_string();
let passages = [
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
];
let passage_instruction = "".to_string();
```
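The scoring step described above boils down to a single matrix product over the (L2-normalized) embeddings. A minimal sketch of just that step, assuming the repo's `candle` crate alias and toy 3-dimensional embeddings in place of the model's real 4096-dimensional outputs:

```rust
use candle::{Device, Result, Tensor};

// Dot-product scoring between query and passage embeddings, scaled by 100,
// as in the retrieval example above. Toy 3-dim embeddings stand in for the
// model's L2-normalized 4096-dim outputs.
fn score(queries: &Tensor, passages: &Tensor) -> Result<Tensor> {
    // queries: (n_q, d), passages: (n_p, d) -> scores: (n_q, n_p)
    (queries.matmul(&passages.t()?)? * 100.0)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let queries = Tensor::new(&[[1f32, 0., 0.], [0., 1., 0.]], &device)?;
    let passages = Tensor::new(&[[0.9f32, 0.1, 0.], [0., 0.8, 0.2]], &device)?;
    println!("{}", score(&queries, &passages)?);
    Ok(())
}
```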
If you already have the model and tokenizer files, you can use the `--tokenizer` and `--model-files` options to specify their full paths, instead of downloading them from the hub.
## Running an example: Sentence embedding
```bash
cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence"
> Embedding: [[ 0.0066, -0.0048, 0.0066, ..., -0.0096, 0.0119, -0.0052]]
> Tensor[[1, 4096], f32]
```
In this example, we pass a prompt to the model and it outputs the vector encoding of the prompt.
## Hardware Requirements
29.25 GB at fp32
## License
CC-BY-NC-4.0. This model should not be used for any commercial purpose. Refer to the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms.

View File

@ -1,214 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use candle::{DType, IndexOp, Shape, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::nvembed_v2::model::Model;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingDirection, PaddingParams, Tokenizer, TruncationParams};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
/// Comma-separated list of model files (e.g., '/path/file1.safetensors,/path/file2.safetensors,/path/file3.safetensors')
#[arg(long)]
model_files: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(Model, tokenizers::Tokenizer)> {
let model_name = match self.model.as_ref() {
Some(model) => model.to_string(),
None => "nvidia/NV-Embed-v2".to_string(),
};
let api = Api::new()?;
let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model));
let model_files = match &self.model_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let tokenizer_file = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let _ = tokenizer
.with_padding(Some(PaddingParams {
direction: PaddingDirection::Right,
pad_id: 2,
pad_token: "</s>".to_string(),
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: 32768,
..Default::default()
}));
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device) }?;
let nvembed_model = Model::new(vb);
Ok((nvembed_model?, tokenizer))
}
}
fn encode(
model: &mut Model,
tokenizer: &Tokenizer,
examples: Vec<String>,
instruction: &str,
) -> Result<Tensor> {
let device = &model.device;
let dtype = model.dtype;
// Format input text
let eos_token = if let Some(padding) = tokenizer.get_padding() {
padding.pad_token.clone()
} else {
"".to_string()
};
let bos = "<s>".to_string();
let input_texts = examples
.iter()
.map(|input_example| format!("{bos}{instruction}{input_example}{eos_token}"))
.collect::<Vec<String>>();
// Tokenize
let encodings = tokenizer.encode_batch(input_texts, false).map_err(E::msg)?;
let input_ids_list = encodings
.iter()
.map(|encoding| {
Tensor::from_slice(
encoding.get_ids(),
Shape::from(encoding.get_ids().len()),
device,
)
})
.collect::<Result<Vec<_>, _>>()?;
let input_ids = Tensor::stack(&input_ids_list, 0)?;
// Mask out padding tokens for both embedding model and latent attention model
let attention_masks: Vec<Tensor> = encodings
.iter()
.map(|encoding| {
Tensor::from_slice(
encoding.get_attention_mask(),
Shape::from(encoding.get_attention_mask().len()),
device,
)?
.to_dtype(dtype)
})
.collect::<Result<Vec<_>, _>>()?;
let attention_mask = Tensor::stack(&attention_masks, 0)?;
// Mask out instruction tokens for latent attention model
let pool_mask = if !instruction.is_empty() {
let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?;
let instruction_lens = encoded_instruction.get_tokens().len();
let zeros = Tensor::zeros(
attention_mask.i((.., ..instruction_lens))?.shape(),
dtype,
device,
)?;
let b = attention_mask.dims()[0];
attention_mask.slice_assign(&[..b, ..instruction_lens], &zeros)?
} else {
attention_mask.clone()
};
let hiddens = model
.forward(&input_ids, &attention_mask, &pool_mask)?
.squeeze(1)?;
// Normalize embedding
div_l2_norm(&hiddens)
}
fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
Ok(v.broadcast_div(&l2_norm)?)
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (mut model, tokenizer) = args.build_model_and_tokenizer()?;
if let Some(prompt) = args.prompt {
let emb = encode(&mut model, &tokenizer, vec![prompt], "")?;
println!("Embedding: {emb}");
} else {
let queries = [
"are judo throws allowed in wrestling?",
"how to become a radiology technician in michigan?",
];
let passages = [
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
];
let passage_instruction = "".to_string();
let query_instruction =
"Instruct: Given a question, retrieve passages that answer the question\nQuery: "
.to_string();
let passages: Vec<String> = passages.iter().map(|s| s.to_string()).collect();
let queries: Vec<String> = queries.iter().map(|s| s.to_string()).collect();
let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?;
let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?;
let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?;
println!("scores: {scores}");
}
Ok(())
}

View File

@ -1,28 +0,0 @@
# PaliGemma
[HuggingFace Model Card](https://huggingface.co/google/paligemma-3b-pt-224) -
[Model Page](https://ai.google.dev/gemma/docs/paligemma)
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.267744448s
caption fr. Un groupe de cyclistes qui sont dans la rue.
13 tokens generated (56.52 token/s)
```
```bash
cargo run --features cuda --release --example paligemma -- \
--prompt "caption fr" --image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0]
loaded the model in 1.271492621s
caption fr une image d' un robot sur la plage avec le mot rouillé
15 tokens generated (62.78 token/s)
```

View File

@ -1,276 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::paligemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = if index > 0 {
self.model.forward(&input)?
} else {
self.model.setup(&self.image, &input)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "google/paligemma-3b-mix-224".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config = Config::paligemma_3b_224();
let image = load_image(&args.image, config.vision_config.image_size)?
.to_device(&device)?
.to_dtype(dtype)?
.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let start = std::time::Instant::now();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
let prompt = format!("{}\n", args.prompt);
pipeline.run(&prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,28 +0,0 @@
# pixtral
Pixtral-12B is a 12B-parameter text+vision model.
[Blog Post](https://mistral.ai/news/pixtral-12b/) -
[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
```bash
cargo run --profile=release-with-debug --features cuda --example pixtral -- \
--image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
Describe the image.
The image depicts a charming, rustic robot standing on a sandy beach at sunset.
The robot has a vintage, steampunk aesthetic with visible gears and mechanical
parts. It is holding a small lantern in one hand, which emits a warm glow, and
its other arm is extended forward as if reaching out or guiding the way. The
robot's body is adorned with the word "RUST" in bright orange letters, adding to
its rustic theme.
The background features a dramatic sky filled with clouds, illuminated by the
setting sun, casting a golden hue over the scene. Gentle waves lap against the
shore, creating a serene and picturesque atmosphere. The overall mood of the
image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
```

View File

@ -1,327 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::pixtral::{vision_model, Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let get_token = |v| match self.tokenizer.get_token(v) {
Some(token) => Ok(token),
None => anyhow::bail!("cannot find the {v} token"),
};
let bos_token = get_token("<s>")?;
let eos_token = get_token("</s>")?;
let inst_token = get_token("[INST]")?;
let end_inst_token = get_token("[/INST]")?;
let img_break = get_token("[IMG_BREAK]")?;
let img_end = get_token("[IMG_END]")?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let logits = if index > 0 {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
self.model.lm_forward(&input)?
} else {
let (_b, _c, h, w) = self.image.dims4()?;
let h = h / self.model.patch_size;
let w = w / self.model.patch_size;
let image_embeds = self.model.encode_image(&self.image)?;
println!("generated image embeddings {image_embeds:?}");
let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let break_embeds = {
let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let start_embeds = {
let mut in_tokens = vec![bos_token, inst_token];
in_tokens.extend_from_slice(tokens.as_slice());
let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let end_embeds = {
let input =
Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let mut input_embeds = vec![start_embeds];
for h_idx in 0..h {
if h_idx > 0 {
input_embeds.push(break_embeds.clone())
}
let row = image_embeds.narrow(1, h_idx * w, w)?;
input_embeds.push(row);
}
input_embeds.push(end_embeds);
let input_embeds = Tensor::cat(&input_embeds, 1)?;
self.model.lm_forward_embeds(&input_embeds)?
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Describe the image.\n")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
#[arg(long)]
vision_only: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "mistral-community/pixtral-12b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let device = candle_examples::device(args.cpu)?;
let dtype = if device.supports_bf16() && !args.vision_only {
DType::BF16
} else {
DType::F32
};
let config: Config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let image = if args.image.ends_with(".safetensors") {
match candle::safetensors::load(&args.image, &device)?.remove("img") {
None => anyhow::bail!("no img tensor in {}", args.image),
Some(v) => v,
}
} else {
candle_examples::imagenet::load_image_with_std_mean(
&args.image,
1024,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.261_302_6, 0.275_777_1],
)?
};
let image = image.to_device(&device)?.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.vision_only {
let start = std::time::Instant::now();
let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
println!("loaded the model in {:?}", start.elapsed());
let embs = model.forward(&image)?;
println!("EMBS\n{embs}");
} else {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
}
Ok(())
}

View File

@ -71,10 +71,6 @@ enum Which {
L8b,
#[value(name = "phi3")]
Phi3,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
}
impl Which {
@ -92,9 +88,7 @@ impl Which {
| Self::Leo7b
| Self::Leo13b
| Self::L8b
| Self::Phi3
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct => false,
| Self::Phi3 => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -130,8 +124,6 @@ impl Which {
| Self::OpenChat35
| Self::Starling7bAlpha
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
@ -158,8 +150,6 @@ impl Which {
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
@ -189,8 +179,6 @@ impl Which {
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
Self::L8b => "meta-llama/Meta-Llama-3-8B",
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
}
}
}
@ -355,14 +343,6 @@ impl Args {
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
),
Which::SmolLM2_360MInstruct => (
"HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
"smollm2-360m-instruct-q8_0.gguf",
),
Which::SmolLM2_1BInstruct => (
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
"smollm2-1.7b-instruct-q4_k_m.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
@ -475,8 +455,6 @@ fn main() -> anyhow::Result<()> {
| Which::Leo7b
| Which::Leo13b
| Which::L8b
| Which::SmolLM2_1BInstruct
| Which::SmolLM2_360MInstruct
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
@ -595,7 +573,6 @@ fn main() -> anyhow::Result<()> {
}
let eos_token = match args.which {
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
Which::L8b => "<|end_of_text|>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",

View File

@ -1,4 +1,5 @@
use std::collections::VecDeque;
use std::fmt::Display;
use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
@ -166,7 +167,6 @@ fn track(
Ok(())
}
#[allow(unused)]
struct Actor<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
@ -211,7 +211,7 @@ impl Actor<'_> {
let target_network = make_network("target-actor")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
Ok(Self {
varmap,
@ -244,7 +244,6 @@ impl Actor<'_> {
}
}
#[allow(unused)]
struct Critic<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
@ -288,7 +287,7 @@ impl Critic<'_> {
let target_network = make_network("target-critic")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
Ok(Self {
varmap,
@ -323,7 +322,6 @@ impl Critic<'_> {
}
}
#[allow(unused)]
#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
actor: Actor<'a>,

View File

@ -1,3 +1,4 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;

View File

@ -1,3 +1,5 @@
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

View File

@ -14,7 +14,7 @@ fn new_model(
) -> Result<(impl Module, VarMap)> {
let input_size = input_shape.iter().product();
let varmap = VarMap::new();
let mut varmap = VarMap::new();
let var_builder = VarBuilder::from_varmap(&varmap, dtype, device);
let model = seq()

View File

@ -1,8 +1,9 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
#[allow(unused)]
#[derive(Debug)]
pub struct Step {
pub obs: Tensor,
@ -10,7 +11,6 @@ pub struct Step {
pub is_done: Tensor,
}
#[allow(unused)]
pub struct VecGymEnv {
env: PyObject,
action_space: usize,
@ -21,7 +21,6 @@ fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
#[allow(unused)]
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {

View File

@ -1,24 +0,0 @@
## SigLIP
SigLIP is a multi-modal text-vision model that improves over CLIP by using a sigmoid-based loss; see the [HuggingFace](https://huggingface.co/google/siglip-base-patch16-224) model card.
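For reference, the sigmoid loss replaces CLIP's batch-wise softmax with an independent binary classification per image/text pair. A minimal sketch of that loss in plain Rust, with made-up similarities and with the temperature `t` and bias `b` treated as fixed constants (they are learnable parameters in the actual model):

```rust
// Pairwise sigmoid loss sketch: each (image i, text j) pair gets label +1 if
// i == j and -1 otherwise; the loss is the mean over images of the summed
// binary log-losses. Toy similarities, fixed t and b.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn sigmoid_loss(sims: &[Vec<f32>], t: f32, b: f32) -> f32 {
    let n = sims.len() as f32;
    let mut loss = 0.0;
    for (i, row) in sims.iter().enumerate() {
        for (j, &s) in row.iter().enumerate() {
            let z = if i == j { 1.0 } else { -1.0 };
            loss -= sigmoid(z * (t * s + b)).ln();
        }
    }
    loss / n
}

fn main() {
    // 2x2 toy cosine similarities between image i and text j.
    let sims = vec![vec![0.9f32, 0.1], vec![0.2, 0.8]];
    println!("loss = {:.4}", sigmoid_loss(&sims, 10.0, -10.0));
}
```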
### Running an example
```
$ cargo run --features cuda -r --example siglip -
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
Probability: 0.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 100.0000% Text: a robot holding a candle
Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
Probability: 100.0000% Text: a cycling race
Probability: 0.0000% Text: a photo of two cats
Probability: 0.0000% Text: a robot holding a candle
```

View File

@ -1,153 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip;
use tokenizers::Tokenizer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long, use_value_delimiter = true)]
images: Option<Vec<String>>,
#[arg(long)]
cpu: bool,
#[arg(long, use_value_delimiter = true)]
sequences: Option<Vec<String>>,
}
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
let img = image::ImageReader::open(path)?.decode()?;
let (height, width) = (image_size, image_size);
let img = img.resize_to_fill(
width as u32,
height as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
Ok(img)
}
fn load_images<T: AsRef<std::path::Path>>(
paths: &Vec<T>,
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for path in paths {
let tensor = load_image(path, image_size)?;
images.push(tensor);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let tokenizer = get_tokenizer(args.tokenizer)?;
let config = siglip::Config::base_patch16_224();
let device = candle_examples::device(args.cpu)?;
let vec_imgs = match args.images {
Some(imgs) => imgs,
None => vec![
"candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
"candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
],
};
let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
println!("softmax_image_vec: {:?}", softmax_image_vec);
let probability_vec = softmax_image_vec
.iter()
.map(|v| v * 100.0)
.collect::<Vec<f32>>();
let probability_per_image = probability_vec.len() / vec_imgs.len();
for (i, img) in vec_imgs.iter().enumerate() {
let start = i * probability_per_image;
let end = start + probability_per_image;
let prob = &probability_vec[start..end];
println!("\n\nResults for image: {}\n", img);
for (i, p) in prob.iter().enumerate() {
println!("Probability: {:.4}% Text: {} ", p, vec_seq[i]);
}
}
Ok(())
}
pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
Tokenizer::from_file(tokenizer).map_err(E::msg)
}
pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
"a cycling race".to_string(),
"a photo of two cats".to_string(),
"a robot holding a candle".to_string(),
],
};
let mut tokens = vec![];
for seq in vec_seq.clone() {
let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
tokens.push(encoding.get_ids().to_vec());
}
let max_len = config.text_config.max_position_embeddings;
// Pad the sequences to have the same length
for token_vec in tokens.iter_mut() {
let len_diff = max_len - token_vec.len();
if len_diff > 0 {
token_vec.extend(vec![pad_id; len_diff]);
}
}
let input_ids = Tensor::new(tokens, device)?;
Ok((input_ids, vec_seq))
}

View File

@ -1,28 +0,0 @@
# candle-splade
SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations offer several advantages over dense approaches: efficient use of inverted indexes, explicit lexical matching, and interpretability. They also seem to generalize better on out-of-domain data. In this example we can do the following two tasks:
- Compute sparse embedding for a given query.
- Compute similarities between a set of sentences using sparse embeddings.
## Sparse Sentence embeddings
SPLADE is used to compute the sparse embedding for a given query. The model weights
are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model.
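Concretely (mirroring the computation done in this example's source), the sparse vector is obtained from the MLM logits with a saturating activation followed by max pooling over the token positions:

$$ w_j = \max_i \log\left(1 + \mathrm{ReLU}(\mathrm{logits}_{i,j})\right) $$

with padded positions masked out via the attention mask (in the batched case) and the result L2-normalized before computing similarities.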
```bash
cargo run --example splade --release -- --prompt "Here is a test sentence"
> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats"
> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146]
```
```bash
cargo run --example splade --release --features
> score: 0.47 'The new movie is awesome' 'The new movie is so great'
> score: 0.43 'The cat sits outside' 'The cat plays in the garden'
> score: 0.14 'I love pasta' 'Do you like pizza?'
> score: 0.11 'A man is playing guitar' 'The cat plays in the garden'
> score: 0.05 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -1,210 +0,0 @@
use std::path::PathBuf;
use anyhow::{Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{self, BertForMaskedLM, Config};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,
// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,
// Path to the config file.
#[arg(long)]
config_file: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "prithivida/Splade_PP_en_v1".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e)));
}
},
},
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(args.cpu)?;
let dtype = bert::DTYPE;
let vb = if weights_filename.ends_with("model.safetensors") {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() }
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap()
};
let model = BertForMaskedLM::load(vb, &config)?;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, None)?;
let vec = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?
.max(1)?;
let vec = normalize_l2(&vec)?;
let vec = vec.squeeze(0)?.to_vec1::<f32>()?;
let indices = (0..vec.len())
.filter(|&i| vec[i] != 0.0)
.map(|x| x as u32)
.collect::<Vec<_>>();
let tokens = tokenizer.decode(&indices, true).unwrap();
println!("{tokens:?}");
let values = indices.iter().map(|&i| vec[i as usize]).collect::<Vec<_>>();
println!("{values:?}");
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Ok(Tensor::new(tokens.as_slice(), &device)?)
})
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
let vector = Tensor::log(
&Tensor::try_from(1.0)?
.to_dtype(dtype)?
.to_device(&device)?
.broadcast_add(&ys.relu()?)?,
)?;
let vector = vector
.broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)?
.max(1)?;
let vec = normalize_l2(&vector)?;
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = vec.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = vec.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -1,71 +0,0 @@
# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3/3.5
![](assets/stable-diffusion-3.jpg)
*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*, generated by Stable Diffusion 3 Medium
Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture.
- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
- [research paper](https://arxiv.org/pdf/2403.03206)
- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
Stable Diffusion 3.5 is a family of text-to-image models with the latest improvements:
- [announcement blog post](https://stability.ai/news/introducing-stable-diffusion-3-5)
It has three variants:
- [Stable Diffusion 3.5 Large](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) @ 8.1b params, with a scaled and slightly modified MMDiT architecture.
- [Stable Diffusion 3.5 Large Turbo](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo) a distilled version that enables 4-step inference.
- [Stable Diffusion 3.5 Medium](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) @ 2.5b params, with an improved MMDiT-X architecture.
## Getting access to the weights
The weights of Stable Diffusion 3/3.5 are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the repos on the HuggingFace Hub to gain access to the weights for your HuggingFace account.
To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli):
```shell
huggingface-cli login
```
and you will be prompted to enter your token.
On the first run, the weights will be automatically downloaded from the Huggingface Hub. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
## Running the model
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- \
--which 3-medium --height 1024 --width 1024 \
--prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
```
To use a different model, change the value of the `--which` option. (Possible values: `3-medium`, `3.5-large`, `3.5-large-turbo` and `3.5-medium`.)
To display the other available options:
```shell
cargo run --example stable-diffusion-3 --release --features=cuda -- --help
```
If your GPU supports it, Flash-Attention is a strongly recommended feature as it can greatly improve inference speed, since MMDiT is a transformer model that relies heavily on attention. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`.
```shell
cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
```
## Performance Benchmark
The benchmark below was done with Stable Diffusion 3 Medium by generating a 1024-by-1024 image from 28 steps of Euler sampling and measuring the average speed (iterations per second).
[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are based on commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):
- Operating System: Ubuntu 23.10
- CPU: i9 12900K w/o overclocking.
- RAM: 64G dual-channel DDR5 @ 4800 MT/s
| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
| -------------- | -------------- | ------------- |
| RTX 3090 Ti | 0.83 | 2.15 |
| RTX 4090 | 1.72 | 4.06 |
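As a rough wall-clock estimate derived from the table above (not measured separately), 28 Euler steps correspond to about 28 / 4.06 ≈ 6.9 s of MMDiT sampling on the RTX 4090 with flash-attn, versus 28 / 1.72 ≈ 16.3 s without it (text encoding and VAE decoding excluded).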

Binary file not shown.


View File

@ -1,234 +0,0 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use std::path::PathBuf;
use tokenizers::tokenizer::Tokenizer;
struct ClipWithTokenizer {
clip: stable_diffusion::clip::ClipTextTransformer,
config: stable_diffusion::clip::Config,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl ClipWithTokenizer {
fn new(
vb: candle_nn::VarBuilder,
config: stable_diffusion::clip::Config,
tokenizer_path: &str,
max_position_embeddings: usize,
) -> Result<Self> {
let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
let path_buf = hf_hub::api::sync::Api::new()?
.model(tokenizer_path.to_string())
.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
"Failed to serialize huggingface PathBuf of CLIP tokenizer",
))?)
.map_err(E::msg)?;
Ok(Self {
clip,
config,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let pad_id = match &self.config.pad_with {
Some(padding) => *self
.tokenizer
.get_vocab(true)
.get(padding.as_str())
.ok_or(E::msg("Failed to tokenize CLIP padding."))?,
None => *self
.tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
};
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let eos_position = tokens.len() - 1;
while tokens.len() < self.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let (text_embeddings, text_embeddings_penultimate) = self
.clip
.forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
Ok((text_embeddings_penultimate, text_embeddings_pooled))
}
}
struct T5WithTokenizer {
t5: t5::T5EncoderModel,
tokenizer: Tokenizer,
max_position_embeddings: usize,
}
impl T5WithTokenizer {
fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok(Self {
t5: model,
tokenizer,
max_position_embeddings,
})
}
fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<Tensor> {
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(self.max_position_embeddings, 0);
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?;
Ok(embeddings)
}
}
pub struct StableDiffusion3TripleClipWithTokenizer {
clip_l: ClipWithTokenizer,
clip_g: ClipWithTokenizer,
clip_g_text_projection: candle_nn::Linear,
t5: T5WithTokenizer,
}
impl StableDiffusion3TripleClipWithTokenizer {
pub fn new_split(
clip_g_file: &PathBuf,
clip_l_file: &PathBuf,
t5xxl_file: &PathBuf,
device: &candle::Device,
) -> Result<Self> {
let vb_clip_g = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
};
let vb_clip_l = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
};
let vb_t5 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)?
};
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_clip_l,
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
let clip_g = ClipWithTokenizer::new(
vb_clip_g,
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn new(vb: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb.pp("clip_l.transformer"),
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;
let clip_g = ClipWithTokenizer::new(
vb.pp("clip_g.transformer"),
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;
let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?;
let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}
pub fn encode_text_to_embedding(
&mut self,
prompt: &str,
device: &candle::Device,
) -> Result<(Tensor, Tensor)> {
let (clip_l_embeddings, clip_l_embeddings_pooled) =
self.clip_l.encode_text_to_embedding(prompt, device)?;
let (clip_g_embeddings, clip_g_embeddings_pooled) =
self.clip_g.encode_text_to_embedding(prompt, device)?;
let clip_g_embeddings_pooled = self
.clip_g_text_projection
.forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
.squeeze(0)?;
let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
.unsqueeze(0)?;
let clip_embeddings_concat = Tensor::cat(
&[&clip_l_embeddings, &clip_g_embeddings],
D::Minus1,
)?
.pad_with_zeros(D::Minus1, 0, 2048)?;
let t5_embeddings = self
.t5
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
Ok((context, y))
}
}

View File

@ -1,273 +0,0 @@
mod clip;
mod sampling;
mod vae;
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
use crate::clip::StableDiffusion3TripleClipWithTokenizer;
use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
use anyhow::{Ok, Result};
use clap::Parser;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3-medium")]
V3Medium,
#[value(name = "3.5-large")]
V3_5Large,
#[value(name = "3.5-large-turbo")]
V3_5LargeTurbo,
#[value(name = "3.5-medium")]
V3_5Medium,
}
impl Which {
fn is_3_5(&self) -> bool {
match self {
Self::V3Medium => false,
Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
}
}
}
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A cute rusty robot holding a candle torch in its hand, \
with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
bright background, high quality, 4k"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use flash_attn to accelerate attention operation in the MMDiT.
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
height: usize,
/// The width in pixels of the generated image.
#[arg(long, default_value_t = 1024)]
width: usize,
/// The model to use.
#[arg(long, default_value = "3-medium")]
which: Which,
/// The seed to use when generating random samples.
#[arg(long)]
num_inference_steps: Option<usize>,
/// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
/// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
/// Use Skip Layer Guidance (SLG) for the sampling.
/// Currently only supports Stable Diffusion 3.5 Medium.
#[arg(long)]
use_slg: bool,
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
tracing,
use_flash_attn,
height,
width,
num_inference_steps,
cfg_scale,
time_shift,
seed,
which,
use_slg,
} = Args::parse();
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(cpu)?;
let default_inference_steps = match which {
Which::V3_5Large => 28,
Which::V3_5LargeTurbo => 4,
Which::V3_5Medium => 28,
Which::V3Medium => 28,
};
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
let default_cfg_scale = match which {
Which::V3_5Large => 4.0,
Which::V3_5LargeTurbo => 1.0,
Which::V3_5Medium => 4.0,
Which::V3Medium => 4.0,
};
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
let api = hf_hub::api::sync::Api::new()?;
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
let sai_repo_for_text_encoders = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
            // Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that are usually
            // placed under the text_encoders directory, as is the case in stabilityai/stable-diffusion-3.5-large and -large-turbo.
            // To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors
            // under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we would need to merge the two partitions
            // to get the monolithic text encoders. This is not a trivial task.
            // Since the situation can change, we do not want to spend effort handling the uniqueness of stabilityai/stable-diffusion-3.5-medium,
            // which involves different paths and merging the two partition files for t5xxl_fp16.safetensors.
            // So for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository.
// TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders.
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let sai_repo_for_mmdit = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
let model_file = {
let model_file = match which {
Which::V3_5Large => "sd3.5_large.safetensors",
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
Which::V3_5Medium => "sd3.5_medium.safetensors",
Which::V3Medium => unreachable!(),
};
sai_repo_for_mmdit.get(model_file)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
&clip_g_file,
&clip_l_file,
&t5xxl_file,
&device,
)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
};
match which {
Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
Which::V3Medium => unreachable!(),
}
} else {
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?;
(MMDiTConfig::sd3_medium(), triple, vb)
};
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
// Drop the text model early to avoid using too much memory.
drop(triple);
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;
if let Some(seed) = seed {
device.set_seed(seed)?;
}
let slg_config = if use_slg {
match which {
// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
scale: 2.5,
start: 0.01,
end: 0.2,
layers: vec![7, 8, 9],
}),
_ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
}
} else {
None
};
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;
sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);
let img = {
let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
        // Apply the TAESD3 scale factor. This seems to significantly improve the quality of the image.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
};
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}

View File

@ -1,83 +0,0 @@
use anyhow::{Ok, Result};
use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
use candle_transformers::models::mmdit::model::MMDiT;
pub struct SkipLayerGuidanceConfig {
pub scale: f64,
pub start: f64,
pub end: f64,
pub layers: Vec<usize>,
}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
mmdit: &MMDiT,
y: &Tensor,
context: &Tensor,
num_inference_steps: usize,
cfg_scale: f64,
time_shift: f64,
height: usize,
width: usize,
slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
.map(|x| x as f64 / num_inference_steps as f64)
.rev()
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
};
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
None,
)?;
let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
if let Some(slg_config) = slg_config.as_ref() {
if (num_inference_steps as f64) * slg_config.start < (step as f64)
&& (step as f64) < (num_inference_steps as f64) * slg_config.end
{
let slg_noise_pred = mmdit.forward(
&x,
&Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
&y.i(..1)?,
&context.i(..1)?,
Some(&slg_config.layers),
)?;
guidance = (guidance
+ (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
}
}
x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}
// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper
// https://arxiv.org/pdf/2403.03206
// Following the implementation in ComfyUI:
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/
// comfy/model_sampling.py#L181
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
alpha * t / (1.0 + (alpha - 1.0) * t)
}
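// Classifier-free guidance: combine the conditional prediction (first element of the
// batch) with the unconditional one (second element) as
// cfg_scale * cond - (cfg_scale - 1) * uncond.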
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
- ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
}

View File

@ -1,93 +0,0 @@
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
let config = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 16,
norm_num_groups: 32,
use_quant_conv: false,
use_post_quant_conv: false,
};
Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
}
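/// Rename tensor paths from the naming used by the candle VAE module (diffusers-style
/// `down_blocks` / `mid_block` / `up_blocks`) to the layout of the SD3 checkpoint.
/// For instance, tracing the logic below, "down_blocks.0.resnets.1.conv1.weight" becomes
/// "down.0.block.1.conv1.weight", and the `up_blocks` indices are reversed so that
/// "up_blocks.0" maps to "up.3".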
pub fn sd3_vae_vb_rename(name: &str) -> String {
let parts: Vec<&str> = name.split('.').collect();
let mut result = Vec::new();
let mut i = 0;
while i < parts.len() {
match parts[i] {
"down_blocks" => {
result.push("down");
}
"mid_block" => {
result.push("mid");
}
"up_blocks" => {
result.push("up");
match parts[i + 1] {
// Reverse the order of up_blocks.
"0" => result.push("3"),
"1" => result.push("2"),
"2" => result.push("1"),
"3" => result.push("0"),
_ => {}
}
i += 1; // Skip the number after up_blocks.
}
"resnets" => {
if i > 0 && parts[i - 1] == "mid_block" {
match parts[i + 1] {
"0" => result.push("block_1"),
"1" => result.push("block_2"),
_ => {}
}
i += 1; // Skip the number after resnets.
} else {
result.push("block");
}
}
"downsamplers" => {
result.push("downsample");
i += 1; // Skip the 0 after downsamplers.
}
"conv_shortcut" => {
result.push("nin_shortcut");
}
"attentions" => {
if parts[i + 1] == "0" {
result.push("attn_1")
}
i += 1; // Skip the number after attentions.
}
"group_norm" => {
result.push("norm");
}
"query" => {
result.push("q");
}
"key" => {
result.push("k");
}
"value" => {
result.push("v");
}
"proj_attn" => {
result.push("proj_out");
}
"conv_norm_out" => {
result.push("norm_out");
}
"upsamplers" => {
result.push("upsample");
i += 1; // Skip the 0 after upsamplers.
}
part => result.push(part),
}
i += 1;
}
result.join(".")
}

View File

@ -1,65 +0,0 @@
# candle-stella-en-v5: Implementation of [stella_en_1.5B_v5](https://huggingface.co/dunzhang/stella_en_1.5B_v5) embedding model
As of 7th Oct 2024, *Stella_en_1.5B_v5* is one of the top-ranking models on the `retrieval` and `reranking` tasks in the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard.
[Model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) on the HuggingFace Hub.
## Running the example
Stella_en_1.5B_v5 is used to generate text embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]
```
Stella_en_1.5B_v5 is trained with [MRL](https://arxiv.org/abs/2205.13147), enabling multiple embedding dimensions.
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b
>
> Score: 0.8178786
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
> Score: 0.7853528
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
>
> Score: 0.8397539
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
>
> Score: 0.809545
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```
## Supported options:
- `Stella_en_v5` has two published model variants: a 1.5B variant and a 400M variant. The variant is selected through the `--which` flag. E.g. `--which 400m` or `--which 1.5b`.
- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for it). In the example run this is selected with the `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled through the `--task` option; see the sketch below.
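Below is a minimal standalone sketch (assuming the prompt template used in this example's source) of how a query is preprocessed for the `s2p` task before tokenization; document passages are encoded without any prefix:

```rust
// Hypothetical illustration of the `s2p` query preprocessing; the actual logic
// lives in `EncodeTask::query_preproc` in the example's main.rs.
fn main() {
    let instruct = "Given a web search query, retrieve relevant passages that answer the query.";
    let query = "What are some ways to reduce stress?";
    let preprocessed = format!("Instruct: {instruct}\nQuery: {query}");
    // `preprocessed` is what gets tokenized and embedded for retrieval.
    println!("{preprocessed}");
}
```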

View File

@ -1,387 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::path::Path;
use anyhow::{anyhow, Error as E, Result};
use clap::Parser;
use candle_transformers::models::stella_en_v5::{
Config, EmbedDim as StellaEmbedDim, EmbeddingModel,
};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo};
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
struct Embedding {
model: EmbeddingModel,
device: Device,
tokenizer: Tokenizer,
}
impl Embedding {
fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self {
Self {
model,
tokenizer,
device: device.clone(),
}
}
fn encode(&mut self, task: EncodeTask, text: Option<String>) -> Result<()> {
        // Just showcasing embeddings, this has no real value
if let Some(text) = text {
let qry = task.query_preproc(&[text]);
let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?;
let shape = (1, encoding.len());
let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?;
let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?;
let result = self.model.forward(&input, &mask)?;
println!("embeddings: {result}");
} else {
// Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers)
let queries = [
"What are some ways to reduce stress?".to_string(),
"What are the benefits of drinking green tea?".to_string(),
];
let docs = [
"There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(),
"Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(),
];
            // Only the queries get the task-specific preprocessing; the docs are encoded as-is
let qry = task.query_preproc(&queries);
let mut qry_encoded = self
.tokenizer
.encode_batch(qry, true)
.map_err(|e| anyhow!(e))?;
let mut docs_encoded = self
.tokenizer
.encode_batch(docs.to_vec(), true)
.map_err(|e| anyhow!(e))?;
let qry_embed = {
// Now, we generate the tensors for the `input` and `mask`
let shape = (qry_encoded.len(), qry_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in qry_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
// Let's generate the embeddings for the query, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
let doc_embed = {
let shape = (docs_encoded.len(), docs_encoded[1].len());
let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
for (i, e) in docs_encoded.drain(..).enumerate() {
let input_id =
Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
.to_dtype(DType::U8)?
.unsqueeze(0)?;
ids =
ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
}
            // Let's generate the embeddings for the docs, we are going to be normalizing the result.
// For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
self.model.forward_norm(&ids, &masks)?
};
println!(
"Embed shapes:\nQuery: {:?}\nDocs: {:?}",
qry_embed.shape(),
doc_embed.shape()
); // [2, 1024] for head dim `1024`
// a matmul to generate the `similarity` score
let res = qry_embed.matmul(&doc_embed.t()?)?;
for (k, v) in queries.iter().enumerate() {
let tnsr = res.get(k)?;
let max = tnsr.argmax(0)?.to_scalar::<u32>()?;
println!(
"\nScore: {}\nQuery: {}\nAnswer: {}\n\n",
tnsr.get(max as usize)?.to_scalar::<f32>()?,
v,
docs[k]
);
}
}
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum EmbedDim {
#[value(name = "256")]
Dim256,
#[value(name = "768")]
Dim768,
#[value(name = "1024")]
Dim1024,
#[value(name = "2048")]
Dim2048,
#[value(name = "4096")]
Dim4096,
#[value(name = "6144")]
Dim6144,
#[value(name = "8192")]
Dim8192,
}
impl EmbedDim {
    /// Returns the dir path to the embed head weights in the repo
pub fn embed_dim_default_dir(&self) -> &'static str {
match self {
Self::Dim256 => "2_Dense_256",
Self::Dim768 => "2_Dense_768",
Self::Dim1024 => "2_Dense_1024",
Self::Dim2048 => "2_Dense_2048",
Self::Dim4096 => "2_Dense_4096",
Self::Dim6144 => "2_Dense_6144",
Self::Dim8192 => "2_Dense_8192",
}
}
    /// Resolves the `EmbedDim` for the given variant
pub fn embed_dim(&self) -> StellaEmbedDim {
match self {
Self::Dim256 => StellaEmbedDim::Dim256,
Self::Dim768 => StellaEmbedDim::Dim768,
Self::Dim1024 => StellaEmbedDim::Dim1024,
Self::Dim2048 => StellaEmbedDim::Dim2048,
Self::Dim4096 => StellaEmbedDim::Dim4096,
Self::Dim6144 => StellaEmbedDim::Dim6144,
Self::Dim8192 => StellaEmbedDim::Dim8192,
}
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
pub enum EncodeTask {
/// `s2p` is the `retrieval` task
/// Default in this example
#[value(name = "s2p")]
S2P,
/// `s2s` is the semantic similarity task
#[value(name = "s2s")]
S2S,
}
impl EncodeTask {
    /// Preprocess a set of inputs based on a template suggested by the model authors
/// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction
pub fn query_preproc(&self, txt: &[String]) -> Vec<String> {
let instruct = match self {
Self::S2P => {
"Given a web search query, retrieve relevant passages that answer the query."
}
Self::S2S => "Retrieve semantically similar text.",
};
txt.iter()
.map(|s| format!("Instruct: {instruct}\nQuery: {s}"))
.collect::<Vec<_>>()
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1.5b")]
Large,
#[value(name = "400m")]
Small,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
which: Which,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
query: Option<String>,
#[arg(long, default_value = "1024")]
embed_dim: Option<EmbedDim>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
base_weight_files: Option<String>,
#[arg(long)]
embed_head_weight_files: Option<String>,
/// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
/// `s2s`: Semantic textual similarity
/// `s2p`: Retrieval task - `Default` in this example
#[arg(long, default_value = "s2p")]
task: Option<EncodeTask>,
}
// Tokenizer creation is critical in our case: the 1.5B variant pads each batch on the
// *left*, while the 400M variant uses the usual right padding.
fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
if which == Which::Large {
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};
        // This part is super important: we are padding the tokens on the *left*, not with the usual *right* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
} else {
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
..Default::default()
}));
}
Ok(tokenizer)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let start = std::time::Instant::now();
let api = Api::new()?;
let embed_dim = match args.embed_dim {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let (repo, cfg) = match args.which {
Which::Large => (
"dunzhang/stella_en_1.5B_v5",
Config::new_1_5_b_v5(embed_dim.embed_dim()),
),
Which::Small => (
"dunzhang/stella_en_400M_v5",
Config::new_400_m_v5(embed_dim.embed_dim()),
),
};
let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
    // Note: if you are providing `weight_files`, ensure that the `--embed-dim` dimension provided matches the weights.
    // E.g. if you are using `--embed-dim 1024`, the weight files should include the `.safetensors` file from the `2_Dense_1024` dir of the repo.
let base_weight_files = match args.base_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
let embed_weight_files = match args.embed_head_weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir());
vec![repo.get(&head_w_path)?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = DType::F32;
let base_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? };
// Embedding layer is always built on F32 for accuracy
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut embedding = Embedding::new(model, tokenizer, &device);
let task = args.task.map_or(EncodeTask::S2P, |t| t);
embedding.encode(task, args.query)
}

View File

@ -10,6 +10,7 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
@ -17,6 +18,7 @@ mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
@ -389,7 +391,6 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -406,7 +407,6 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
@ -427,7 +427,6 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
@ -480,10 +479,6 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
/// The input device to use.
#[arg(long)]
device: Option<String>,
}
pub fn main() -> Result<()> {
@ -548,12 +543,13 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let mut decoder = Decoder::new(
let language_token = None;
let mut dc = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
/* language_token */ None,
language_token,
args.task,
args.timestamps,
args.verbose,
@ -569,83 +565,47 @@ pub fn main() -> Result<()> {
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let audio_device = match args.device.as_ref() {
None => host.default_input_device(),
Some(device) => host
.input_devices()?
.find(|x| x.name().map_or(false, |y| &y == device)),
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
}
.expect("failed to find the audio input device");
.expect("failed to find input device");
let audio_config = audio_device
let _config = _device
.default_input_config()
.expect("Failed to get default input config");
println!("audio config {audio_config:?}");
let channel_count = audio_config.channels() as usize;
let in_sample_rate = audio_config.sample_rate().0 as usize;
let resample_ratio = 16000. / in_sample_rate as f64;
let mut resampler = rubato::FastFixedIn::new(
resample_ratio,
10.,
rubato::PolynomialDegree::Septic,
1024,
1,
)?;
let (tx, rx) = std::sync::mpsc::channel();
let stream = audio_device.build_input_stream(
&audio_config.config(),
move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
let pcm = pcm
.iter()
.step_by(channel_count)
.copied()
.collect::<Vec<f32>>();
if !pcm.is_empty() {
tx.send(pcm).unwrap()
}
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
let channel_count = _config.channels() as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();
std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
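        // Keep at most ~16 of the latest chunks in the ring buffer (roughly 4.8s of audio at
        // ~300ms per recording call); once the buffer grows past that, drop the oldest chunk.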
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
// loop to process the audio data forever (until the user stops the program)
println!("transcribing audio...");
let mut buffered_pcm = vec![];
let mut language_token_set = false;
while let Ok(pcm) = rx.recv() {
use rubato::Resampler;
buffered_pcm.extend_from_slice(&pcm);
if buffered_pcm.len() < 10 * in_sample_rate {
continue;
}
let mut resampled_pcm = vec![];
// resample the audio, one chunk of 1024 samples at a time.
// in case the audio input failed to produce an exact multiple of 1024 samples,
// process the remainder on the next iteration of the loop.
let full_chunks = buffered_pcm.len() / 1024;
let remainder = buffered_pcm.len() % 1024;
for chunk in 0..full_chunks {
let buffered_pcm = &buffered_pcm[chunk * 1024..(chunk + 1) * 1024];
let pcm = resampler.process(&[&buffered_pcm], None)?;
resampled_pcm.extend_from_slice(&pcm[0]);
}
let pcm = resampled_pcm;
println!("{} {}", buffered_pcm.len(), pcm.len());
if remainder == 0 {
buffered_pcm.clear();
} else {
// efficiently copy the remainder to the beginning of the `buffered_pcm` buffer and
            // truncate it. That's more efficient than allocating a new vector and copying into it
println!("audio device produced partial chunk with {remainder} samples; processing the remainder on the next iteration of the loop");
buffered_pcm.copy_within(full_chunks * 1024.., 0);
buffered_pcm.truncate(remainder);
}
let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
@ -654,13 +614,9 @@ pub fn main() -> Result<()> {
)?;
// on the first iteration, we detect the language and set the language token.
if !language_token_set {
if i == 0 {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(
decoder.model(),
&tokenizer,
&mel,
)?),
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
@ -671,12 +627,47 @@ pub fn main() -> Result<()> {
}
};
println!("language_token: {:?}", language_token);
decoder.set_language_token(language_token);
language_token_set = true;
dc.set_language_token(language_token);
}
decoder.run(&mel, None)?;
decoder.reset_kv_cache();
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
}
Ok(())
}
fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
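    // Keep every third sample, assuming the default input device runs at 48 kHz, so that
    // the recorded audio roughly matches the model's expected 16 kHz sample rate.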
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

View File

@ -12,7 +12,7 @@ file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/sample
from the hub.
```bash
cargo run --example whisper --release --features="symphonia"
cargo run --example whisper --release
> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }

View File

@ -370,7 +370,6 @@ enum WhichModel {
Large,
LargeV2,
LargeV3,
LargeV3Turbo,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
@ -389,7 +388,6 @@ impl WhichModel {
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::LargeV3Turbo
| Self::DistilLargeV2
| Self::DistilLargeV3 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
@ -411,7 +409,6 @@ impl WhichModel {
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"),

View File

@ -6,6 +6,7 @@ pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
/// Loads an image from disk using the image crate at the requested resolution,
/// using the given std and mean parameters.
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
p: P,
res: usize,

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.1"
version = "0.7.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.1"
version = "0.7.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -82,17 +82,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= MIN_CC_DP4A
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
#define MMQ_X_Q4_0_RDNA2 64
#define MMQ_Y_Q4_0_RDNA2 128
#define NWARPS_Q4_0_RDNA2 8
@ -1832,8 +1821,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
}
const float2 ds8f = __half22float2(ds8);
@ -1855,8 +1844,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
}
#ifdef GGML_CUDA_F16
@ -1889,14 +1878,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}
const float2 ds8f = __half22float2(ds8);
@ -1920,14 +1909,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}
#ifdef GGML_CUDA_F16
@ -1956,7 +1945,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
sumi = __dp4a(v[i], u[i], sumi);
}
return d8_0*d8_1 * sumi;
@ -1970,7 +1959,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
sumi = __dp4a(v[i], u[i], sumi);
}
#ifdef GGML_CUDA_F16
@ -2005,13 +1994,13 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
const int vi = (v >> (2*i)) & 0x03030303;
sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
// fill int with 4x m
int m = sc >> 4;
m |= m << 8;
m |= m << 16;
sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
}
const float2 dm2f = __half22float2(dm2);
@ -2040,8 +2029,8 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_d_sc = ggml_cuda_dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
sumi_m = ggml_cuda_dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
}
sumi_d += sumi_d_sc * (sc & 0xF);
@ -2082,7 +2071,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
const int vi = __vsubss4(vil, vih);
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d3 * sumf;
@ -2100,7 +2089,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
int sumi_sc = 0;
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
}
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@ -2125,8 +2114,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
@ -2151,7 +2140,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
@ -2187,8 +2176,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int v0i = vl0i | vh0i;
const int v1i = vl1i | vh1i;
const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]);
@ -2214,7 +2203,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
@ -2248,7 +2237,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d*sumf;
@ -2267,11 +2256,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
#pragma unroll
for (int i = i0; i < i0 + 2; ++i) {
sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
}
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
@ -2499,10 +2488,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const int v1 = q4[0];
const int v2 = q4[4];
const int dot1 = ggml_cuda_dp4a(ui2, v2 & 0x0f0f0f0f, ggml_cuda_dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = ggml_cuda_dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, ggml_cuda_dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = ggml_cuda_dp4a(0x01010101, ui2, ggml_cuda_dp4a(0x01010101, ui1, 0));
const int dot4 = ggml_cuda_dp4a(0x01010101, ui4, ggml_cuda_dp4a(0x01010101, ui3, 0));
const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
@ -2587,8 +2576,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
const float sumf_d = d8_1 * (ggml_cuda_dp4a(ui1, v1, 0) * s[0] + ggml_cuda_dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (ggml_cuda_dp4a(ui3, v3, 0) * s[2] + ggml_cuda_dp4a(ui4, v4, 0) * s[3]);
const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
return d * sumf_d;
#endif
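Note: the `ggml_cuda_dp4a` helper in the first hunk of this file wraps the hardware `__dp4a` instruction and, on architectures without it, expands to the same computation: each 32-bit operand is read as four packed signed bytes, the lanes are multiplied pairwise, and the four products are added to the accumulator. A scalar Rust sketch of that semantics (illustrative only, not part of the kernels):

```rust
/// Scalar model of dp4a: treat `a` and `b` as four packed i8 lanes,
/// multiply lane-wise, and accumulate the products into `c`.
fn dp4a(a: i32, b: i32, c: i32) -> i32 {
    let a8 = a.to_le_bytes();
    let b8 = b.to_le_bytes();
    let mut acc = c;
    for i in 0..4 {
        acc += (a8[i] as i8 as i32) * (b8[i] as i8 as i32);
    }
    acc
}
```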

View File

@ -70,9 +70,10 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
@ -133,9 +134,10 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float tmp = 0.0f; // partial sum for thread in warp
@ -528,15 +530,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const int n_cols, const int block_size, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, block_size, eps); \
const int n_cols, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, block_size, eps); \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
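Note: whichever side of this diff is built, the per-row math is unchanged: both `layernorm` and `rmsnorm` accumulate in `f32` across one row per block, then normalize and apply the scale (and bias for layernorm). A plain Rust reference of the layernorm row computation, offered as a sketch for clarity rather than a port of the kernel:

```rust
/// CPU reference for one layernorm row (f32 accumulation, matching the
/// kernel comments above). `alpha` is the scale, `beta` the bias.
fn layernorm_row(x: &[f32], alpha: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(alpha.iter().zip(beta.iter()))
        .map(|(v, (a, b))| (v - mean) * inv_std * a + b)
        .collect()
}
```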

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.1"
version = "0.7.1"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -1,39 +0,0 @@
#include <metal_stdlib>
using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \
fill_with<T>(out, value, numel, tid); \
} \
#define FILL_OPS(NAME, T) \
FILL_OP(NAME, T) \
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)
FILL_OPS(f16, half)
FILL_OPS(f32, float)
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif

View File

@ -17,33 +17,33 @@ METAL_FUNC uint get_strided_index(
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
constant bool &contiguous,
constant size_t *src_dims,
constant size_t *src_strides,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
output[tid] = input[strided_src_i];
}
@ -68,25 +68,25 @@ kernel void NAME( \
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
@ -105,27 +105,27 @@ kernel void NAME( \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
METAL_FUNC void scatter_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
@ -145,28 +145,28 @@ kernel void NAME( \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
constant size_t &ids_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
METAL_FUNC void index_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
constant size_t &ids_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
@ -193,16 +193,12 @@ INDEX_OP(is_i64_f16, int64_t, half)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif
INDEX_OP(is_u32_u8, uint32_t, uint8_t)
INDEX_OP(is_u32_u32, uint32_t, uint32_t)
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif
INDEX_OP(is_u8_u8, uint8_t, uint8_t)
INDEX_OP(is_u8_u32, uint8_t, uint32_t)
INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
@ -214,12 +210,10 @@ GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
GATHER_OP(gather_u32_u32, uint, uint)
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_u32, uint32_t, uint32_t)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)

View File

@ -6,15 +6,14 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
pub mod utils;
mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@ -25,7 +24,6 @@ const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
const TERNARY: &str = include_str!("ternary.metal");
const UNARY: &str = include_str!("unary.metal");
const SDPA: &str = include_str!("scaled_dot_product_attention.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@ -33,7 +31,6 @@ pub enum Source {
Binary,
Cast,
Conv,
Fill,
Gemm,
Indexing,
Mfa,
@ -43,7 +40,6 @@ pub enum Source {
Sort,
Ternary,
Unary,
Sdpa,
}
pub mod copy2d {
@ -161,17 +157,6 @@ pub enum MetalKernelError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("Sdpa {variation} head size was {got}, expected {expected:?}")]
SdpaHeadSizeMismatch {
variation: &'static str,
got: usize,
expected: Vec<usize>,
},
#[error("Sdpa {variation} got dtype {got:?}")]
SdpaHeadDTypeMismatch {
variation: &'static str,
got: SdpaDType,
},
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@ -211,7 +196,6 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
@ -220,7 +204,6 @@ impl Kernels {
Source::Sort => SORT,
Source::Ternary => TERNARY,
Source::Unary => UNARY,
Source::Sdpa => SDPA,
Source::Mfa => panic!("Invalid lib"),
}
}
@ -372,7 +355,7 @@ pub fn call_unary_contiguous_tiled(
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
let tiles = (length + tile_size - 1) / tile_size;
encoder.set_compute_pipeline_state(&pipeline);
@ -594,7 +577,7 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
(elements_to_sum as u64).div_ceil(2),
(elements_to_sum as u64 + 2 - 1) / 2,
)
.next_power_of_two();
@ -1641,313 +1624,6 @@ pub fn call_gemm(
Ok(())
}
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,
F16,
F32,
}
/// SDPA full is supported when:
/// - q head dim == 64, 128
/// - no mask
/// - q heads == kv heads
/// - final type != bf16 (TODO maybe just template this kernel too?)
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct MLXFastAttentionParams {
m: i32,
n: i32,
k: i32,
ldq: i32, // ldq == ldo
ldk: i32,
ldv: i32,
lds: i32,
ldo: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_q: i32,
batch_stride_k: i32,
batch_stride_v: i32,
batch_stride_o: i32,
swizzle_log: i32,
gemm_n_iterations_aligned: i32,
gemm_k_iterations_aligned: i32,
gemm_sv_m_block_iterations: i32,
batch_ndim: i32,
alpha: f32,
softcapping: f32,
}
let bk = q_shape.last().unwrap();
const BN: usize = 16;
const BM: usize = 16;
const WM: usize = 2;
const WN: usize = 2;
let name = match (bk, itype) {
(32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
(64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
(96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
(128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
(256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
(32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
(64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
(96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
(128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
(256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
(other, SdpaDType::F16 | SdpaDType::F32) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
(_, SdpaDType::BF16) => {
return Err(MetalKernelError::SdpaHeadDTypeMismatch {
variation: "full",
got: SdpaDType::BF16,
})
}
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)
let qseq = q_shape[q_shape.len() - 2];
let m = q_shape[q_shape.len() - 2];
let n = m;
let k = q_shape[q_shape.len() - 1];
let bs_out = q_shape[0] * q_shape[1];
let batch_shape = [q_shape[0] * q_shape[1]];
let dk = q_shape[q_shape.len() - 1];
let ldq = dk;
let ldk = dk;
let ldv = dk;
let lds = BN;
let ldo = dk;
let tn = 1;
let tm = m.div_ceil(BM);
let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
let gemm_n_iterations_aligned = n.div_ceil(BN);
let gemm_k_iterations_aligned = k.div_ceil(*bk);
let gemm_sv_m_block_iterations = m.div_ceil(BM);
let batch_ndim = batch_shape.len();
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let params = MLXFastAttentionParams {
m: m as i32,
n: n as i32,
k: k as i32,
ldq: ldq as i32,
ldk: ldk as i32,
ldv: ldv as i32,
lds: lds as i32,
ldo: ldo as i32,
tiles_n: tn,
tiles_m: tm as i32,
batch_stride_q: b_stride_q as i32,
batch_stride_k: b_stride_k as i32,
batch_stride_v: b_stride_v as i32,
batch_stride_o: b_stride_o as i32,
swizzle_log,
gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
batch_ndim: batch_ndim as i32,
alpha,
softcapping,
};
let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
impl EncoderParam for MLXFastAttentionParams {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<MLXFastAttentionParams>() as u64,
&data as *const MLXFastAttentionParams as *const c_void,
);
}
}
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
&batch_shape[..],
&batch_strides[..]
)
);
let grid_dims = MTLSize {
width: 1,
height: tm as u64,
depth: bs_out as u64,
};
let group_dims = MTLSize {
width: 32,
height: WM as u64,
depth: WN as u64,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
/// SDPA vector is supported when:
/// - q head dim == 64, 96, 128
/// - no mask
/// - q,k,v are contiguous
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_stride: &[usize],
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
let bk = q_shape.last().unwrap();
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let kstride = k_stride[1];
let vstride = v_stride[1];
let name = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
(32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
(32, SdpaDType::F32) => "sdpa_vector_float_32",
(64, SdpaDType::F32) => "sdpa_vector_float_64",
(96, SdpaDType::F32) => "sdpa_vector_float_96",
(128, SdpaDType::F32) => "sdpa_vector_float_128",
(256, SdpaDType::F32) => "sdpa_vector_float_256",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
};
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, kv_seq, hidden)
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
gqa_factor,
n,
kstride,
vstride,
alpha,
softcapping
)
);
let grid_dims = MTLSize {
width: 1,
height: b as u64,
depth: 1_u64,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col1d_strided(
device: &Device,
@ -2320,7 +1996,7 @@ pub fn call_quantized_matmul_mv_t(
}
fn divide(m: usize, b: usize) -> NSUInteger {
m.div_ceil(b) as NSUInteger
((m + b - 1) / b) as NSUInteger
}
#[allow(clippy::too_many_arguments)]
@ -2681,25 +2357,5 @@ pub fn call_mlx_gemm(
Ok(())
}
pub fn call_const_fill(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (output, v, length));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[cfg(test)]
mod tests;

File diff suppressed because it is too large

View File

@ -1,7 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
use rand::Rng;
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@ -2308,33 +2307,3 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let buffer = dev.new_buffer(
(len * std::mem::size_of::<T>()) as u64,
MTLResourceOptions::StorageModePrivate,
);
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);
test::<i64, _>("fill_i64", |v| v as i64);
test::<f16, _>("fill_f16", f16::from_f32);
test::<bf16, _>("fill_bf16", bf16::from_f32);
test::<f32, _>("fill_f32", |v| v);
}

View File

@ -8,7 +8,7 @@ use std::ffi::c_void;
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = size.div_ceil(width);
let count = (size + width - 1) / width;
let thread_group_count = MTLSize {
width: count,
height: 1,
@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@ -61,14 +61,18 @@ pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub trait EncoderParam {
pub(crate) trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
@ -128,7 +132,7 @@ impl EncoderParam for (&Buffer, usize) {
}
}
impl EncoderParam for &BufferOffset<'_> {
impl<'a> EncoderParam for &BufferOffset<'a> {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
}
@ -169,7 +173,7 @@ pub struct WrappedEncoder<'a> {
end_encoding_on_drop: bool,
}
impl Drop for WrappedEncoder<'_> {
impl<'a> Drop for WrappedEncoder<'a> {
fn drop(&mut self) {
if self.end_encoding_on_drop {
self.inner.end_encoding()
@ -177,15 +181,14 @@ impl Drop for WrappedEncoder<'_> {
}
}
impl AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'_> {
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
self.inner
}
}
impl EncoderProvider for &metal::CommandBuffer {
type Encoder<'a>
= WrappedEncoder<'a>
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
@ -197,8 +200,7 @@ impl EncoderProvider for &metal::CommandBuffer {
}
impl EncoderProvider for &metal::CommandBufferRef {
type Encoder<'a>
= WrappedEncoder<'a>
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {
@ -210,8 +212,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
}
impl EncoderProvider for &ComputeCommandEncoderRef {
type Encoder<'a>
= WrappedEncoder<'a>
type Encoder<'a> = WrappedEncoder<'a>
where
Self: 'a;
fn encoder(&self) -> Self::Encoder<'_> {

View File

@ -1,5 +1,3 @@
//! Activation Functions
//!
use candle::{Result, Tensor};
use serde::Deserialize;

View File

@ -9,7 +9,7 @@ pub struct Func<'a> {
f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>,
}
impl std::fmt::Debug for Func<'_> {
impl<'a> std::fmt::Debug for Func<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "func")
}
@ -22,7 +22,7 @@ where
Func { f: Arc::new(f) }
}
impl super::Module for Func<'_> {
impl<'a> super::Module for Func<'a> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
(*self.f)(xs)
}
@ -44,7 +44,7 @@ pub struct FuncT<'a> {
f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
}
impl std::fmt::Debug for FuncT<'_> {
impl<'a> std::fmt::Debug for FuncT<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "func")
}
@ -57,7 +57,7 @@ where
FuncT { f: Arc::new(f) }
}
impl super::ModuleT for FuncT<'_> {
impl<'a> super::ModuleT for FuncT<'a> {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
(*self.f)(xs, train)
}

View File

@ -1,5 +1,3 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};
#[derive(Debug, Clone)]

View File

@ -1,20 +1,3 @@
//! candle-nn
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds structs and functions
//! that allow you to build and train neural nets. You may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models.
//!
pub mod activation;
pub mod batch_norm;
pub mod conv;

View File

@ -1,5 +1,3 @@
//! Loss Calculations
//!
use candle::{Result, Tensor};
/// The negative log likelihood loss.

View File

@ -1,6 +1,3 @@
//! Tensor ops.
//!
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
@ -546,23 +543,15 @@ impl candle::CustomOp2 for RmsNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (block_size, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -787,24 +776,15 @@ impl candle::CustomOp3 for LayerNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (block_size, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -967,193 +947,3 @@ impl Module for Identity {
Ok(xs.clone())
}
}
#[allow(dead_code)]
struct Sdpa {
scale: f32,
softcapping: f32,
}
impl candle::CustomOp3 for Sdpa {
fn name(&self) -> &'static str {
"metal-sdpa"
}
fn cpu_fwd(
&self,
_s1: &CpuStorage,
_l1: &Layout,
_s2: &CpuStorage,
_l2: &Layout,
_s3: &CpuStorage,
_l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
candle::bail!("SDPA has no cpu impl")
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
q: &candle::MetalStorage,
q_l: &Layout,
k: &candle::MetalStorage,
k_l: &Layout,
v: &candle::MetalStorage,
v_l: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle_metal_kernels::SdpaDType;
let device = q.device();
let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
let elem_count: usize = out_dims.iter().product();
let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
// q,k must have matching emb dim
if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
candle::bail!("`q` and `k` last dims must match");
}
// k,v must have matching n kv heads
if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
candle::bail!("`k` and `v` head dims must match");
}
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}
let k_head = k_l.dim(D::Minus1)?;
let q_head = q_l.dim(D::Minus1)?;
let q_seq = q_l.dim(2)?;
let mut implementation_supports_use_case = q_head == k_head;
let supported_head_dim =
q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
const SDPA_FULL_THRESHOLD: usize = 2;
let supports_sdpa_full =
q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
if !supported_head_dim {
candle::bail!(
"Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
if !implementation_supports_use_case {
candle::bail!(
"Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
q_l.dims(),
k_l.dims(),
v_l.dims()
);
}
for t in [k.dtype(), v.dtype()] {
if q.dtype() != t {
candle::bail!("all q, k, v dtypes must match.");
}
}
let itype = match q.dtype() {
DType::BF16 => SdpaDType::BF16,
DType::F16 => SdpaDType::F16,
DType::F32 => SdpaDType::F32,
other => candle::bail!("unsupported sdpa type {other:?}"),
};
let command_buffer = q.device().command_buffer()?;
if supports_sdpa_vector {
command_buffer.set_label("vector_attention");
candle_metal_kernels::call_sdpa_vector(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k_l.dims(),
k_l.stride(),
k.buffer(),
v_l.start_offset(),
v_l.stride(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else if supports_sdpa_full {
if q_l.dim(2)? != k_l.dim(2)? {
candle::bail!(
"query and key sequence length must be equal if using full metal sdpa"
)
}
command_buffer.set_label("full_attention");
candle_metal_kernels::call_sdpa_full(
q.device().device(),
&command_buffer,
q.device().kernels(),
q_l.start_offset(),
q_l.dims(),
q.buffer(),
k_l.start_offset(),
k.buffer(),
v_l.start_offset(),
v.buffer(),
&output,
self.scale,
self.softcapping,
itype,
)
.map_err(candle::Error::wrap)?;
} else {
candle::bail!("must be vector or full sdpa kernel");
}
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
Ok((newstorage, Shape::from_dims(&out_dims)))
}
}
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Inputs shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
/// - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
/// - Use a vectorized kernel
/// - Supports `seq` != `kv_seq` (cross attn. support)
/// - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
/// - Use an alternate kernel
/// - Requires `seq` == `kv_seq`
/// - GQA is not supported (requires `qhead` == `kv_head`)
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}
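Note: given the shape and softcapping conventions spelled out in the doc comment above, a minimal call site could look like the sketch below (the wrapper name is illustrative; it assumes a Metal device, contiguous `q`/`k`/`v`, and one of the supported head dims):

```rust
use candle::{D, Result, Tensor};

/// Hedged usage sketch: scale by 1/sqrt(head_dim) and disable softcapping
/// by passing 1.0, per the doc comment above.
fn attend(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(D::Minus1)?;
    let scale = 1.0 / (head_dim as f32).sqrt();
    candle_nn::ops::sdpa(q, k, v, scale, 1.0)
}
```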

View File

@ -70,12 +70,6 @@ impl LSTMState {
}
}
#[derive(Debug, Clone, Copy)]
pub enum Direction {
Forward,
Backward,
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
@ -84,7 +78,6 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
pub direction: Direction,
}
impl Default for LSTMConfig {
@ -95,7 +88,6 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -108,7 +100,6 @@ impl LSTMConfig {
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
direction: Direction::Forward,
}
}
}
@ -116,7 +107,7 @@ impl LSTMConfig {
/// A Long Short-Term Memory (LSTM) layer.
///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms)]
#[allow(clippy::upper_case_acronyms, unused)]
#[derive(Clone, Debug)]
pub struct LSTM {
w_ih: Tensor,
@ -129,62 +120,6 @@ pub struct LSTM {
dtype: DType,
}
impl LSTM {
/// Creates a LSTM layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let layer_idx = config.layer_idx;
let direction_str = match config.direction {
Direction::Forward => "",
Direction::Backward => "_reverse",
};
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_ih_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(
4 * hidden_dim,
&format!("bias_hh_l{layer_idx}{direction_str}"),
init,
)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &LSTMConfig {
&self.config
}
}
/// Creates a LSTM layer.
pub fn lstm(
in_dim: usize,
@ -192,7 +127,39 @@ pub fn lstm(
config: LSTMConfig,
vb: crate::VarBuilder,
) -> Result<LSTM> {
LSTM::new(in_dim, hidden_dim, config, vb)
let layer_idx = config.layer_idx;
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
&format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
&format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
}
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => {
Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
}
None => None,
};
Ok(LSTM {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
impl RNN for LSTM {
@ -286,7 +253,7 @@ impl GRUConfig {
/// A Gated Recurrent Unit (GRU) layer.
///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms)]
#[allow(clippy::upper_case_acronyms, unused)]
#[derive(Clone, Debug)]
pub struct GRU {
w_ih: Tensor,
@ -299,56 +266,41 @@ pub struct GRU {
dtype: DType,
}
impl GRU {
/// Creates a GRU layer.
pub fn new(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<Self> {
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(Self {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &GRUConfig {
&self.config
}
}
/// Creates a GRU layer.
pub fn gru(
in_dim: usize,
hidden_dim: usize,
config: GRUConfig,
vb: crate::VarBuilder,
) -> Result<GRU> {
GRU::new(in_dim, hidden_dim, config, vb)
let w_ih = vb.get_with_hints(
(3 * hidden_dim, in_dim),
"weight_ih_l0", // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(3 * hidden_dim, hidden_dim),
"weight_hh_l0", // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
None => None,
};
let b_hh = match config.b_hh_init {
Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
None => None,
};
Ok(GRU {
w_ih,
w_hh,
b_ih,
b_hh,
hidden_dim,
config,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
impl RNN for GRU {
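Note: both constructors in this file pull their weights from a `VarBuilder` using the single-layer `weight_ih_l{layer_idx}` / `weight_hh_l{layer_idx}` names shown above. A hedged usage sketch for the LSTM path (the dims and the zero-initialized builder are purely illustrative; a real model would load a checkpoint instead):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{lstm, LSTMConfig, VarBuilder, RNN};

fn run_lstm() -> Result<Tensor> {
    let dev = Device::Cpu;
    // Zero weights just to keep the sketch self-contained.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let layer = lstm(16, 32, LSTMConfig::default(), vb)?;
    // Input shape: (batch, seq_len, in_dim).
    let xs = Tensor::zeros((2, 5, 16), DType::F32, &dev)?;
    let states = layer.seq(&xs)?;
    // Stack the per-step hidden states back into a tensor.
    layer.states_to_tensor(&states)
}
```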

View File

@ -1,5 +1,3 @@
//! Rotary Embeddings
//!
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;

View File

@ -1,5 +1,3 @@
//! Sequential Layer
//!
//! A sequential layer used to chain multiple layers and closures.
use candle::{Module, Result, Tensor};

Some files were not shown because too many files have changed in this diff.