Compare commits

...

30 Commits
0.5.1 ... 0.6.0

Author SHA1 Message Date
a3dd87f15e Adding Gemm and ArgMax operators to candle-onnx (#2231)
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-28 21:40:31 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DPTHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
6baa1d486b Fix a bug in the metal implementation of col2im1d. (#2284) 2024-06-22 23:21:20 +02:00
36cf54525d Fix the fast bf16 gemm cublas kernels. (#2274)
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
2024-06-18 23:46:58 +02:00
2b10aaa05d implement Slice op (#2260) 2024-06-12 07:15:32 +01:00
9f804af29d feat(ci): add trufflehog secrets detection (#2262)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 21:03:54 +01:00
54ff971e35 Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.

* Add more models.
2024-06-07 10:51:50 +01:00
b9fac7ec00 implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode

The intent of this change is to allow eval of the current silero_vad.onnx (v4).
This ONNX file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now (see the usage sketch after this entry).

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-06 22:36:23 +02:00
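As a hedged illustration of what this enables (not part of the commit): a minimal sketch of pushing an ONNX file such as silero_vad.onnx through candle-onnx's `simple_eval`. The file path, input name, and shapes are placeholders; the real silero model also expects its sample-rate and recurrent state tensors.

```rust
use std::collections::HashMap;
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // Load the ONNX graph; `If` and `Pad` (reflect) nodes are now handled by simple_eval.
    let model = candle_onnx::read_file("silero_vad.onnx")?;
    if let Some(graph) = model.graph.as_ref() {
        // List the declared inputs so the right tensors can be supplied.
        for input in graph.input.iter() {
            println!("expects input: {}", input.name);
        }
    }
    let mut inputs = HashMap::new();
    // Placeholder: one 30ms chunk of 16kHz audio. The real model also needs
    // "sr" and its state tensors, omitted in this sketch.
    inputs.insert(
        "input".to_string(),
        Tensor::zeros((1, 480), DType::F32, &Device::Cpu)?,
    );
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {:?}", value.shape());
    }
    Ok(())
}
```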
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
d39462856b Apply rustfmt. (#2247) 2024-06-04 22:54:09 +02:00
cb180eb23a ONNX: add ArgMin, ArgMax and LeakyRelu (#2246)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

* Added ArgMin operator implementation

* Added tests for ArgMin

* ArgMin now returns a tensor with i64

* Added tests from pytorch examples

* Added ArgMax operator implementation

* Added tests for ArgMax

* Added LeakyRelu implementation

* Added a test for LeakyRelu

* Typo fix

* Fix a weird automatic RustRover change

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-06-04 22:49:02 +02:00
9182c828e6 Automatically upcast for to_u64 (#2244) 2024-06-04 11:32:36 +02:00
3f13ad3d79 Fix dataset id for MNIST (#2238) 2024-06-04 06:27:24 +02:00
cd4d941ed1 Add LLaVA support (#2234)
* first commit

* llava

* clippy and fmt

* some fixes

* minor fixes

* remove useless file

* refactor: Remove llava/constants.rs and update llava/mod.rs

* modify variable name

* modify code after clippy

* Minor tweaks.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-03 11:54:09 +02:00
03344d3c19 ONNX: Add Floor and Ceil (#2235) 2024-06-02 21:45:20 +02:00
1ec3b2cc18 add where_cond f32 for metal (#2236) 2024-06-02 14:30:06 +02:00
f7773d498a Deactivate some book test that breaks the CI. (#2233)
* Deactivate some book test that breaks the CI.

* Clippy fix.
2024-06-01 09:44:22 +02:00
7abc3b8cd7 Bump cudarc version to 0.11.4 (#2230) 2024-06-01 08:18:35 +02:00
46012ed31f Another cudarc update. (#2229) 2024-05-30 22:27:06 +02:00
f3fade3b03 Update cudarc to 0.11.2. (#2227) 2024-05-29 18:50:52 +02:00
ea260aeffd Add Debug, Clone, Deserialize to moondream config (#2222) 2024-05-28 06:08:00 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
6f0b807ffd More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d.

* Small tweak.
2024-05-24 11:05:43 +02:00
d54e02d73d Avoid a contiguous call in the quantized phi 3 model. (#2209)
* Simplify the KvCache api.

* Avoid a contiguous call in the quantized phi3 model.
2024-05-23 21:24:55 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
77ea479a18 Add Phi-3 Medium (#2205) 2024-05-23 13:33:17 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
56 changed files with 4802 additions and 223 deletions

.github/workflows/trufflehog.yml (new file)
View File

@ -0,0 +1,15 @@
on:
  push:
name: Secret Leaks
jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.5.1"
version = "0.6.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
candle-nn = { path = "./candle-nn", version = "0.5.1" }
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
candle-nn = { path = "./candle-nn", version = "0.6.0" }
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"

View File

@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
}
}
#[allow(unused)]
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};

View File

@ -9,8 +9,10 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
@ -19,6 +21,7 @@ fn main() -> Result<()> {
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();

View File

@ -10,7 +10,7 @@ pub use utils::{
};
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@ -121,7 +121,8 @@ impl ReduceIndex {
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
let dst_to_set =
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
@ -2249,7 +2250,7 @@ impl BackendStorage for CpuStorage {
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
if USE_IM2COL_CONV1D_TR && can_use_col2im {
if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {

View File

@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
};
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let ys_to_set = unsafe {
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
};
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];

View File

@ -16,7 +16,7 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
col: &CudaSlice<T>,
dev: &CudaDevice,
l: &Layout,
) -> Result<CudaSlice<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let device = self.device().clone();
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col.slice, &device, &col_l)?
} else {
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
};
Ok(Self { slice, device })
}
@ -1964,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16(
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
let beta_f32: f32 = cfg.gemm.beta.to_f32();
let alpha = f16::from_f32(alpha_f32);
let beta = f16::from_f32(beta_f32);
// The type for alpha and beta depends on the computeType.
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
(
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
(&alpha) as *const f16 as *const _,
(&beta) as *const f16 as *const _,
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
(&alpha_f32) as *const f32 as *const _,
(&beta_f32) as *const f32 as *const _,
)
} else {
(

View File

@ -54,6 +54,44 @@ pub trait Map2 {
}
}
pub trait Map3 {
#[allow(clippy::too_many_arguments)]
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
src3: &CudaSlice<T>,
layout3: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
#[allow(clippy::too_many_arguments)]
fn map(
&self,
s1: &S,
l1: &Layout,
s2: &S,
l2: &Layout,
s3: &S,
l3: &Layout,
d: &CudaDevice,
) -> Result<S> {
let out = match (s1, s2, s3) {
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,

View File

@ -718,6 +718,7 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
(DType::U32, DType::F32) => "where_u32_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",
@ -824,44 +825,107 @@ impl BackendStorage for MetalStorage {
k_layout: &Layout,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let can_use_col2im = k_layout.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = layout.shape().dims3()?;
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
layout.shape(),
k_layout.shape()
)
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let name = match self.dtype {
DType::F32 => "col2im1d_f32",
DType::U32 => "col2im1d_u32",
DType::U8 => "col2im1d_u8",
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
};
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
k_layout.start_offset(),
);
self.matmul(
k,
(b_size, l_in, c_out * k_size, c_in),
&layout.transpose(1, 2)?,
&kernel_l_mm,
)?
};
// It is important for the command buffer to be obtained *after* the matmul
// kernel has run, otherwise we might use a command-buffer that has been committed
// already, resulting in the following error.
// _status < MTLCommandBufferStatusCommitted >
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_col2im1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&[b_size, l_in, c_out, k_size],
params.k_size,
params.stride,
BufferOffset::zero_offset(&col.buffer),
&buffer,
)
.map_err(MetalError::from)?;
buffer
} else {
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
buffer
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}

View File

@ -217,10 +217,16 @@ impl Value {
}
}
/// This will also automatically upcast any integral types which will not truncate.
pub fn to_u64(&self) -> Result<u64> {
match self {
Self::U64(v) => Ok(*v),
v => crate::bail!("not a u64 {v:?}"),
// Autoupcast cases here
Self::U8(v) => Ok(*v as u64),
Self::U16(v) => Ok(*v as u64),
Self::U32(v) => Ok(*v as u64),
Self::Bool(v) => Ok(*v as u64),
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
}
}
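A hedged aside (not part of the diff): assuming the `Value` shown here is the GGUF metadata value type in candle-core (`candle::quantized::gguf_file::Value`), a minimal sketch of what the upcasting change allows.

```rust
use candle::quantized::gguf_file::Value;

fn main() -> candle::Result<()> {
    // Smaller unsigned integers (and bools) now upcast losslessly to u64.
    assert_eq!(Value::U8(3).to_u64()?, 3);
    assert_eq!(Value::U16(7).to_u64()?, 7);
    assert_eq!(Value::U32(1024).to_u64()?, 1024);
    assert_eq!(Value::Bool(true).to_u64()?, 1);
    // Anything else still errors out rather than truncating.
    assert!(Value::I64(-1).to_u64().is_err());
    Ok(())
}
```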

View File

@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let dataset_id = "ylecun/mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,

View File

@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
[[example]]
name = "llama_multiprocess"
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["encodec"]
[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]

View File

@ -0,0 +1,13 @@
# candle-dinov2
[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
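A minimal sketch of that instantiation flow (not part of the README), assuming the `dinov2_vits14.safetensors` and `depth_anything_v2_vits.safetensors` files have already been downloaded locally; the actual example fetches them from the Hub and does proper image loading and normalization:

```rust
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // DINOv2 ViT-S/14 backbone.
    let vb_dino = unsafe {
        VarBuilder::from_mmaped_safetensors(&["dinov2_vits14.safetensors"], DType::F32, &device)?
    };
    let dinov2 = dinov2::vit_small(vb_dino)?;
    // Depth Anything V2 head on top of the backbone.
    let vb_depth = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["depth_anything_v2_vits.safetensors"],
            DType::F32,
            &device,
        )?
    };
    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb_depth)?;
    // A normalized (1, 3, 518, 518) image tensor stands in for the preprocessed input.
    let image = Tensor::zeros((1, 3, 518, 518), DType::F32, &device)?;
    let depth = depth_anything.forward(&image)?;
    println!("depth map shape: {:?}", depth.shape());
    Ok(())
}
```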
## Running an example with color map and CUDA
```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

View File

@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;
use candle::Tensor;
pub struct SpectralRColormap {
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}
impl SpectralRColormap {
pub(crate) fn new() -> Self {
// Define a colormap similar to 'Spectral_r' by specifying key colors.
// got the colors from ChatGPT-4o
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
]);
Self { gradient }
}
fn get_color(&self, value: f32) -> LinSrgb {
self.gradient.gen(value)
}
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
println!("Gray: {:?}", gray.dims());
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
let rgb_values: Vec<f32> = gray_values
.iter()
.map(|g| self.get_color(*g))
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
.collect();
let [.., height, width] = gray.dims() else {
candle::bail!("Not enough dims!")
};
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
color.permute((2, 0, 1))
}
}

View File

@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;
use crate::color_map::SpectralRColormap;
mod color_map;
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
const DINO_IMG_SIZE: usize = 518;
#[derive(Parser)]
struct Args {
#[arg(long)]
dinov2_model: Option<PathBuf>,
#[arg(long)]
depth_anything_v2_model: Option<PathBuf>,
#[arg(long)]
image: PathBuf,
#[arg(long)]
output_dir: Option<PathBuf>,
#[arg(long)]
cpu: bool,
#[arg(long)]
color_map: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let dinov2_model_file = match args.dinov2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-dino-v2".into());
api.get("dinov2_vits14.safetensors")?
}
Some(dinov2_model) => dinov2_model,
};
println!("Using file {:?}", dinov2_model_file);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
let dinov2 = dinov2::vit_small(vb)?;
println!("DinoV2 model built");
let depth_anything_model_file = match args.depth_anything_v2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
api.get("depth_anything_v2_vits.safetensors")?
}
Some(depth_anything_model) => depth_anything_model,
};
println!("Using file {:?}", depth_anything_model_file);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
println!("Loaded image {image:?}");
let depth = depth_anything.forward(&image)?;
println!("Got predictions {:?}", depth.shape());
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
let output_path = full_output_path(&args.image, &args.output_dir);
println!("Saving image to {}", output_path.to_string_lossy());
save_image(&output_image, output_path)?;
Ok(())
}
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
let input_file_name = image_path.file_name().unwrap();
let mut output_file_name = OsString::from("depth_");
output_file_name.push(input_file_name);
let mut output_path = match output_dir {
None => image_path.parent().unwrap().to_path_buf(),
Some(output_path) => output_path.clone(),
};
output_path.push(output_file_name);
output_path
}
fn load_and_prep_image(
image_path: &PathBuf,
device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
.unsqueeze(0)?
.to_dtype(F32)?
.to_device(&device)?;
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(&device)?
.broadcast_as(image.shape())?;
let image = (image / max_pixel_val)?;
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
Ok((original_height, original_width, image))
}
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
let mean_tensor =
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
let std_tensor =
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
image.sub(&mean_tensor)?.div(&std_tensor)
}
fn post_process_image(
image: &Tensor,
original_height: usize,
original_width: usize,
color_map: bool,
) -> Result<Tensor> {
let out = image.interpolate2d(original_height, original_width)?;
let out = scale_image(&out)?;
let out = if color_map {
let spectral_r = SpectralRColormap::new();
spectral_r.gray2color(&out)?
} else {
let rgb_slice = [&out, &out, &out];
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
};
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(out.device())?
.broadcast_as(out.shape())?;
let out = (out * max_pixel_val)?;
out.to_dtype(U8)
}
fn scale_image(depth: &Tensor) -> Result<Tensor> {
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
let min_val_tensor = Tensor::try_from(*min_val)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
let depth = (depth - min_val_tensor)?;
let range = max_val - min_val;
let range_tensor = Tensor::try_from(range)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
depth / range_tensor
}

View File

@ -0,0 +1,4 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";

View File

@ -0,0 +1,114 @@
pub enum SeparatorStyle {
Two,
Mpt,
}
pub struct Conversation {
pub system: String,
pub roles: Vec<String>,
pub messages: Vec<(String, Option<String>)>,
pub offset: i32,
pub sep_style: SeparatorStyle,
pub sep: String,
pub sep2: Option<String>,
pub version: String,
}
impl Conversation {
pub fn new(
system: &str,
roles: &[String],
offset: i32,
sep_style: SeparatorStyle,
sep: &str,
sep2: Option<&str>,
version: &str,
) -> Self {
Conversation {
system: system.to_string(),
roles: roles.to_vec(),
messages: Vec::new(),
offset,
sep_style,
sep: sep.to_string(),
sep2: sep2.map(|s| s.to_string()),
version: version.to_string(),
}
}
pub fn conv_chatml_direct() -> Self {
Conversation::new(
"<|im_start|>system\nAnswer the questions.",
&[
"<|im_start|>user\n".to_string(),
"<|im_start|>assistant\n".to_string(),
],
0,
SeparatorStyle::Mpt,
"<|im_end|>",
None,
"mpt",
)
}
pub fn conv_llava_v1() -> Self {
Conversation::new(
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
&[
"USER".to_string(),
"ASSISTANT".to_string(),
],
0,
SeparatorStyle::Two,
" ",
Some("</s>"),
"v1"
)
}
pub fn append_message(&mut self, role: String, message: Option<&str>) {
self.messages.push((role, message.map(|s| s.to_string())))
}
pub fn append_user_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[0].clone(), message);
}
pub fn append_assistant_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[1].clone(), message);
}
pub fn get_prompt(&self) -> String {
match self.sep_style {
SeparatorStyle::Mpt => {
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&self.sep);
for (role, message) in &self.messages {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(message);
};
ret.push_str(&self.sep);
}
ret
}
SeparatorStyle::Two => {
let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&seps[0]);
for (i, (role, message)) in self.messages.iter().enumerate() {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
ret.push_str(message);
ret.push_str(&seps[i % 2]);
} else {
ret.push(':')
}
}
ret
}
}
}
}

View File

@ -0,0 +1,317 @@
use std::cmp::min;
use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
config::{HFPreProcessorConfig, LLaVAConfig},
utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};
// This struct is mainly for LLaVA applications, hence it's not completely compatible with the python transformers CLIPImageProcessor; it only covers the few preprocessing steps that LLaVA uses, including those of "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
#[serde(default = "default_size")]
pub size: u32, // this is not the same as python transformer
#[serde(default = "default_do_resize")]
pub do_resize: bool,
//resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom
#[serde(default = "default_do_center_crop")]
pub do_center_crop: bool,
#[serde(default = "default_crop_size")]
pub crop_size: u32, // this is not the same as python transformer
#[serde(default = "default_do_rescale")]
pub do_rescale: bool,
#[serde(default = "default_rescale_factor")]
pub rescale_factor: f32,
#[serde(default = "default_do_normalize")]
pub do_normalize: bool,
#[serde(default = "default_image_mean")]
pub image_mean: Vec<f32>,
#[serde(default = "default_image_std")]
pub image_std: Vec<f32>,
}
fn default_size() -> u32 {
224
}
fn default_do_resize() -> bool {
true
}
fn default_do_center_crop() -> bool {
true
}
fn default_crop_size() -> u32 {
224
}
fn default_do_rescale() -> bool {
true
}
fn default_rescale_factor() -> f32 {
1.0 / 255.0
}
fn default_do_normalize() -> bool {
true
}
fn default_image_mean() -> Vec<f32> {
vec![0.48145466, 0.4578275, 0.40821073]
}
fn default_image_std() -> Vec<f32> {
vec![0.26862954, 0.2613026, 0.2757771]
}
impl ImageProcessor {
pub fn from_pretrained(clip_id: &str) -> Result<Self> {
let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
let api = api.model(clip_id.to_string());
let config_filename = api
.get("preprocessor_config.json")
.map_err(|e| candle::Error::Msg(e.to_string()))?;
let image_processor =
serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
.map_err(|e| candle::Error::Msg(e.to_string()))?;
Ok(image_processor)
}
pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
Self {
size: hf_preprocessor_config.size["shortest_edge"] as u32,
do_resize: hf_preprocessor_config.do_resize,
do_center_crop: hf_preprocessor_config.do_center_crop,
crop_size: hf_preprocessor_config.crop_size["height"] as u32,
do_rescale: hf_preprocessor_config.do_rescale,
rescale_factor: hf_preprocessor_config.rescale_factor,
do_normalize: hf_preprocessor_config.do_normalize,
image_mean: hf_preprocessor_config.image_mean.clone(),
image_std: hf_preprocessor_config.image_std.clone(),
}
}
/// Resize so the shortest edge equals `self.size`; the other edge is scaled to maintain the aspect ratio.
pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let size = self.size;
if width == size && height == size {
image.clone()
} else {
let (new_width, new_height) = if width < height {
(
size,
(((size * height) as f32) / width as f32).ceil() as u32,
)
} else {
(
(((size * width) as f32) / height as f32).ceil() as u32,
size,
)
};
image.resize(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
)
}
}
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let crop_size = self.crop_size;
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
image.crop_imm(left, top, crop_size, crop_size)
}
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
let img = image.to_rgb8().into_raw();
let (width, height) = image.dimensions();
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
.to_dtype(DType::F32) // only for internal compute
}
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
let rescale_factor = self.rescale_factor as f64;
tensor.affine(rescale_factor, 0.0)
}
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
let image_mean = self.image_mean.clone();
let image_std = self.image_std.clone();
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
}
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
tensor.permute((2, 0, 1))
}
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
let image = if self.do_resize {
self.resize(image)
} else {
image.clone()
};
let image = if self.do_center_crop {
self.center_crop(&image)
} else {
image
};
let tensor = self.to_tensor(&image)?;
let tensor = if self.do_rescale {
self.rescale(&tensor)?
} else {
tensor
};
let tensor = if self.do_normalize {
self.normalize(&tensor)?
} else {
tensor
};
self.to_channel_dimension_format(&tensor)
}
}
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
let (width, height) = image_size;
let (center_width, center_height) = center_size;
let left = if width <= center_width {
0
} else {
((width as f32 - center_width as f32) / 2.0).ceil() as u32
};
let top = if height <= center_height {
0
} else {
((height as f32 - center_height as f32) / 2.0).ceil() as u32
};
(left, top)
}
pub fn process_image(
image: &DynamicImage,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
if llava_config.image_aspect_ratio == *"square" {
processor.preprocess(image)?.unsqueeze(0)
} else if llava_config.image_aspect_ratio == *"anyres" {
process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
} else if llava_config.image_aspect_ratio == *"pad" {
process_pad_image(image, processor)
} else {
bail!("Invalid image aspect ratio")
}
}
fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
let mean_color = processor
.image_mean
.iter()
.map(|x| ((*x) * 255.0) as u8)
.collect::<Vec<u8>>();
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
let image_padded = expand2square(image, mean_color);
processor.preprocess(&image_padded)
}
fn process_anyres_image(
image: &DynamicImage,
processor: &ImageProcessor,
grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
let original_size = image.dimensions();
let best_resolution = select_best_resolution(original_size, grid_pinpoints);
let image_padded = resize_and_pad_image(image, best_resolution);
let image_original_resize = image.resize_exact(
processor.size,
processor.size,
image::imageops::FilterType::CatmullRom,
);
let mut patches = vec![image_original_resize];
for patch in divide_to_patches(&image_padded, processor.crop_size) {
patches.push(patch);
}
let tensors = patches
.iter()
.map(|patch| processor.preprocess(patch))
.collect::<Result<Vec<Tensor>>>()?;
Tensor::stack(&tensors, 0)
}
fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
let (width, height) = image.dimensions();
match width.cmp(&height) {
std::cmp::Ordering::Less => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
new_image
}
std::cmp::Ordering::Equal => image.clone(),
std::cmp::Ordering::Greater => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
new_image
}
}
}
fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
let (original_width, original_height) = image.dimensions();
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let (target_width, target_height) = target_resolution;
let target_width_f = target_width as f32;
let target_height_f = target_height as f32;
let scale_w = target_width_f / original_width_f;
let scale_h = target_height_f / original_height_f;
let (new_width, new_height) = if scale_w < scale_h {
(
target_width,
min((original_height_f * scale_w).ceil() as u32, target_height),
)
} else {
(
min((original_width_f * scale_h).ceil() as u32, target_width),
target_height,
)
};
let resized_image = image.resize_exact(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
);
let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
let (paste_x, paste_y) =
calculate_middle((target_width, target_height), (new_width, new_height));
overlay(
&mut new_image,
&resized_image,
paste_x.into(),
paste_y.into(),
);
new_image
}
fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
let (width, height) = image.dimensions();
let mut patches = Vec::new();
for y in (0..height).step_by(patch_size as usize) {
for x in (0..width).step_by(patch_size as usize) {
let patch = image.crop_imm(x, y, patch_size, patch_size);
patches.push(patch);
}
}
patches
}

View File

@ -0,0 +1,316 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about,long_about=None)]
struct Args {
#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
model_path: String,
#[arg(long, default_value = "tokenizer/tokenizer.json")]
tokenizer_path: String,
#[arg(long)]
model_base: Option<String>,
#[arg(long)]
image_file: String, // Required
#[arg(long)]
conv_mode: Option<String>,
#[arg(long, default_value_t = 0.2)]
temperature: f32,
#[arg(long, default_value_t = 512)]
max_new_tokens: usize,
#[arg(long, action)]
hf: bool,
#[arg(long, action)]
cpu: bool,
#[arg(long, action)]
no_kv_cache: bool,
#[arg(long)]
prompt: String,
/// The seed to use when generating random samples. Copied from the candle llama example; it does not exist in the python llava.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
path: T,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
dtype: DType,
) -> Result<((u32, u32), Tensor)> {
let img = image::io::Reader::open(path)?.decode()?;
let img_tensor = process_image(&img, processor, llava_config)?;
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}
fn get_model_name_from_path(model_path: &str) -> String {
let model_paths: Vec<String> = model_path
.trim_matches('/')
.split('/')
.map(|s| s.to_string())
.collect();
if model_paths.last().unwrap().starts_with("checkpoint-") {
format!(
"{}_{}",
model_paths[model_paths.len() - 2],
model_paths.last().unwrap()
)
} else {
model_paths.last().unwrap().to_string()
}
}
fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
T: Clone,
{
let mut res = Vec::new();
for _ in 0..n {
res.extend(vec.to_owned());
}
res
}
fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
T: Clone,
{
let sep = vec![sep];
let sep = duplicate_vec(&sep, x.len());
let mut res = x
.iter()
.zip(sep.iter())
.flat_map(|(x, y)| vec![x.clone(), y.clone()])
.collect::<Vec<Vec<T>>>();
res.pop();
res
}
fn tokenizer_image_token(
prompt: &str,
tokenizer: &Tokenizer,
image_token_index: i64,
llava_config: &LLaVAConfig,
) -> Result<Tensor> {
let prompt_chunks = prompt
.split("<image>")
.map(|s| {
tokenizer
.encode(s, true)
.unwrap()
.get_ids()
.to_vec()
.iter()
.map(|x| *x as i64)
.collect()
})
.collect::<Vec<Vec<i64>>>();
let mut input_ids = Vec::new();
let mut offset = 0;
if !prompt_chunks.is_empty()
&& !prompt_chunks[0].is_empty()
&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
{
offset = 1;
input_ids.push(prompt_chunks[0][0]);
}
for x in insert_separator(
prompt_chunks,
duplicate_vec(&[image_token_index], offset + 1),
)
.iter()
{
input_ids.extend(x[1..].to_vec())
}
let input_len = input_ids.len();
Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}
fn main() -> Result<()> {
let mut args = Args::parse();
let device = candle_examples::device(args.cpu)?;
println!("Start loading model");
let api = Api::new()?;
let api = api.model(args.model_path.clone());
let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
let config_filename = api.get("config.json")?;
let hf_llava_config: HFLLaVAConfig =
serde_json::from_slice(&std::fs::read(config_filename)?)?;
let generation_config_filename = api.get("generation_config.json")?;
let generation_config: HFGenerationConfig =
serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
let preprocessor_config_filename = api.get("preprocessor_config.json")?;
let preprocessor_config: HFPreProcessorConfig =
serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
let llava_config =
hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
let tokenizer_filename = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let clip_vision_config = hf_llava_config.to_clip_vision_config();
(
llava_config,
tokenizer,
Some(clip_vision_config),
ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
)
} else {
let config_filename = api.get("config.json")?;
let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
(
llava_config.clone(),
tokenizer,
None,
ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
)
};
let llama_config = llava_config.to_llama_config();
let dtype: DType = match llava_config.torch_dtype.as_str() {
"float16" => DType::F16,
"bfloat16" => DType::BF16,
_ => bail!("unsupported dtype"),
};
let eos_token_id = llava_config.eos_token_id;
println!("setting kv cache");
let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
println!("loading model weights");
let weight_filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
println!("generating conv template");
let image_token_se = format!(
"{}{}{}",
DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
);
let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
if llava_config.mm_use_im_start_end {
args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
} else {
args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
}
} else if llava_config.mm_use_im_start_end {
format!("{}\n{}", image_token_se, args.prompt)
} else {
format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
};
let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
let conv_mode = if model_name.contains("llama-2") {
"llava_llama_2"
} else if model_name.contains("mistral") {
"mistral_instruct"
} else if model_name.contains("v1.6-34b") {
"chatml_direct"
} else if model_name.contains("v1") {
"llava_v1"
} else if model_name.contains("mpt") {
"mpt"
} else {
"llava_v0"
};
if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
println!(
"Warning: the model is trained with {}, but you are using {}",
conv_mode,
args.conv_mode.as_deref().unwrap()
);
} else {
args.conv_mode = Some(conv_mode.to_string());
}
let mut conv = match args.conv_mode {
Some(conv_mode) => match conv_mode.as_str() {
"chatml_direct" => Conversation::conv_chatml_direct(),
"llava_v1" => Conversation::conv_llava_v1(),
_ => todo!("not implement yet"),
},
None => bail!("conv_mode is required"),
};
conv.append_user_message(Some(&qs));
conv.append_assistant_message(None);
let prompt = conv.get_prompt();
println!("loading image");
let (image_size, image_tensor) =
load_image(&args.image_file, &image_processor, &llava_config, dtype)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
let image_tensor = image_tensor.to_device(&device)?;
let mut logits_processor = {
let temperature = f64::from(args.temperature);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
Sampling::All { temperature }
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
// get input tokens
let tokens = tokenizer_image_token(
&prompt,
&tokenizer,
llava_config.image_token_index as i64,
&llava_config,
)?;
let mut input_embeds =
llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let mut index_pos = 0;
for index in 0..args.max_new_tokens {
let (_, input_embeds_len, _) = input_embeds.dims3()?;
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(input_embeds_len, 0)
};
let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
let logits = logits.squeeze(0)?;
let (_, input_len, _) = input.dims3()?;
index_pos += input_len;
let next_token = logits_processor.sample(&logits)?;
let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
if next_token == eos_token_id as u32 {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
Ok(())
}

View File

@ -0,0 +1,40 @@
# candle-llava
LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava)
The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA); hence the llava-hf version of the config may perform differently.
## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)
Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.
## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file, so they can be used directly with the `--hf` command line flag.
For the original llava models, you can use the following code to generate the `tokenizer.json` file.
```bash
conda create -n llava python=3.10
pip install transformers protobuf
conda activate llava
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).
## eval
```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```
## Major Limitations
1. Currently only the llama-2/vicuna LLM is supported; Mistral is not supported yet.
2. Some ops, like split, nonzero, and where, are not supported by candle.
3. Lack of quantization and LoRA support.

View File

@ -141,6 +141,8 @@ enum WhichModel {
V2,
#[value(name = "3")]
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@ -254,6 +256,7 @@ fn main() -> Result<()> {
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -273,6 +276,7 @@ fn main() -> Result<()> {
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
@ -287,7 +291,8 @@ fn main() -> Result<()> {
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
| WhichModel::V3 => repo.get("tokenizer.json")?,
| WhichModel::V3
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@ -303,14 +308,14 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V3 => anyhow::bail!(
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
@ -332,7 +337,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
WhichModel::V3 => {
WhichModel::V3 | WhichModel::V3Medium => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
@ -352,7 +357,9 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
if args.model == WhichModel::V3 && device.is_cuda() {
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
&& device.is_cuda()
{
DType::BF16
} else {
DType::F32
@ -368,7 +375,7 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
WhichModel::V3 => {
WhichModel::V3 | WhichModel::V3Medium => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;

View File

@ -217,7 +217,6 @@ fn main() -> anyhow::Result<()> {
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
1,
args.use_flash_attn,
model,
&mut file,

View File

@ -144,6 +144,14 @@ enum WhichModel {
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
#[value(name = "2-0.5b")]
W2_0_5b,
#[value(name = "2-1.5b")]
W2_1_5b,
#[value(name = "2-7b")]
W2_7b,
#[value(name = "2-72b")]
W2_72b,
}
#[derive(Parser, Debug)]
@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
let (version, size) = match args.model {
WhichModel::W2_0_5b => ("2", "0.5B"),
WhichModel::W2_1_5b => ("2", "1.5B"),
WhichModel::W2_7b => ("2", "7B"),
WhichModel::W2_72b => ("2", "72B"),
WhichModel::W0_5b => ("1.5", "0.5B"),
WhichModel::W1_8b => ("1.5", "1.8B"),
WhichModel::W4b => ("1.5", "4B"),
WhichModel::W7b => ("1.5", "7B"),
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
};
format!("Qwen/Qwen1.5-{size}")
format!("Qwen/Qwen{version}-{size}")
}
};
let repo = api.repo(Repo::with_revision(
@ -261,11 +273,15 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W2_7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
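With these changes the new flags resolve to the Qwen2 repositories on the hub while the existing flags keep pointing at Qwen1.5. A quick illustration of the resulting model ids (the flag names come from the clap value attributes above):

// --model 2-7b      -> "Qwen/Qwen2-7B"
// --model 0.5b      -> "Qwen/Qwen1.5-0.5B"
// --model moe-a2.7b -> "Qwen/Qwen1.5-MoE-A2.7B"
let (version, size) = ("2", "7B");
assert_eq!(format!("Qwen/Qwen{version}-{size}"), "Qwen/Qwen2-7B");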

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.5.1"
version = "0.6.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.5.1"
version = "0.6.0"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -97,6 +97,50 @@ __device__ void im2col1d(
}
}
template <typename T>
__device__ void col2im1d(
const size_t dst_el,
const size_t l_out,
const size_t l_in,
const size_t c_out,
const size_t k_size,
const size_t stride,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
#define COL2IM1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_el, \
const size_t l_out, \
const size_t l_in, \
const size_t c_out, \
const size_t k_size, \
const size_t stride, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
COL2IM1D_OP(__half, col2im1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(double, col2im1d_f64)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
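For reference, a minimal CPU sketch of the accumulation the new col2im1d kernels perform: the kernels parallelize over destination elements, whereas this version loops over source columns, but the resulting sums are identical. The function name and the f32 element type are illustrative only.

/// src is laid out as (b_size, l_in, c_out, k_size), dst as (b_size, c_out, l_out),
/// with l_out = (l_in - 1) * stride + k_size as in the launch code.
fn col2im1d_ref(
    src: &[f32],
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    // Each unfolded column element contributes to output position l * stride + k.
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}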

View File

@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x;
}
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
if (alpha == nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs);
}
}
else if (alpha == nullptr && beta != nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs + b);
}
}
else if (alpha != nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a);
}
}
else {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a + b);
}
}
}
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
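For reference, the layernorm kernels above accumulate the per-row sum and sum of squares in f32 in a single pass and then apply the usual affine normalization:

$$\mu = \tfrac{1}{n}\sum_i x_i, \qquad \sigma^2 = \tfrac{1}{n}\sum_i x_i^2 - \mu^2, \qquad y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}\,\alpha_i + \beta_i$$

When the alpha or beta pointer is null, the corresponding scale or shift is simply skipped, as in the four branches above.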

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.5.1"
version = "0.6.0"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -68,6 +68,50 @@ METAL_FUNC void im2col(
}
}
template <typename T>
METAL_FUNC void col2im1d(
constant size_t &dst_el,
constant size_t &l_out,
constant size_t &l_in,
constant size_t &c_out,
constant size_t &k_size,
constant size_t &stride,
device const T *src,
device T *dst,
uint dst_i [[ thread_position_in_grid ]]
) {
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
@ -190,6 +234,21 @@ kernel void FN_NAME( \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define COL2IM1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_el, \
constant size_t &l_out, \
constant size_t &l_in, \
constant size_t &c_out, \
constant size_t &k_size, \
constant size_t &stride, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
} \
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL_OP(bfloat, im2col_bf16)
#endif
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(half, float, conv_transpose2d_f16)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
#endif
#endif

View File

@ -739,6 +739,69 @@ pub fn call_rms_norm(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
eps: f32,
input: &Buffer,
input_offset: usize,
alpha: &Buffer,
alpha_offset: usize,
beta: &Buffer,
beta_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
(beta, beta_offset),
eps
)
);
let out_length = length / elements_to_sum;
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@ -1588,6 +1651,39 @@ pub fn call_im2col1d_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
k_size: usize,
stride: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_in = shape[1];
let c_out = shape[2];
let l_out = (l_in - 1) * stride + k_size;
let dst_el = shape[0] * c_out * l_out;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,

View File

@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
}
}
template<typename T>
METAL_FUNC void layernorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
device const T * beta,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp1 = 0;
float tmp2 = 0;
while (idx < stop_idx) {
tmp1 += float(src[idx]);
tmp2 += float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp1;
shared_memory[tid + block_dim] = tmp2;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = shared_memory[0] / float(el_to_sum_per_block);
float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
float inv_norm = 1.0f / sqrt(var + eps);
idx = start_idx + tid;
while (idx < stop_idx) {
float val = (float(src[idx]) - mean) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
if (beta != nullptr) {
val += float(beta[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -371,6 +430,25 @@ kernel void NAME( \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
#define LAYERNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
device const T *beta, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif

View File

@ -1,5 +1,4 @@
#include <metal_stdlib>
#
using namespace metal;
METAL_FUNC uint get_strided_index(
@ -57,27 +56,31 @@ kernel void FN_NAME(
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
} \
// WHERE_OP(float, int64_t, where_i64_f32)
// WHERE_OP(double, int64_t, where_i64_f64)
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
// WHERE_OP(int64_t, int64_t, where_i64_i64)
//
// WHERE_OP(float, uint32_t, where_u32_f32)
// WHERE_OP(double, uint32_t, where_u32_f64)
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, uint32_t, where_u32_f16)
WHERE_OP(float, uint32_t, where_u32_f32)
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, int64_t, where_i64_f16)
WHERE_OP(float, int64_t, where_i64_f32)
WHERE_OP(uint8_t, int64_t, where_i64_u8)
WHERE_OP(uint32_t, int64_t, where_i64_u32)
WHERE_OP(int64_t, int64_t, where_i64_i64)
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, int64_t, where_i64_bf16)
#endif
#endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
#endif
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
#endif

View File

@ -1023,6 +1023,27 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
#[test]
fn where_cond_u32_f32() {
let shape = vec![6];
let cond = vec![0u32, 1, 0, 0, 1, 1];
let cond_l = (vec![1], 0);
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let left_l = (vec![1], 0);
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
let right_l = (vec![1], 0);
let results = run_where_cond(
&shape,
&cond,
cond_l,
&left_true,
left_l,
&right_false,
right_l,
"where_u32_f32",
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),

View File

@ -5,7 +5,7 @@ use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;
fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input);
}
const B: usize = 1;

View File

@ -1,30 +1,25 @@
use candle::{DType, Device, Result, Shape, Tensor};
use candle::{Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
all_data: Tensor,
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
// on the first call where the batch size is easily known.
// Also this makes it safe to clone a KvCache that has been reset (as in it will not share
// its internal state with the cloned instance).
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
max_seq_len: usize,
}
impl Cache {
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
dim: D,
shape: S,
dtype: DType,
dev: &Device,
) -> Result<Self> {
let shape = shape.into();
let dim = dim.to_index(&shape, "kv-cache")?;
let max_seq_len = shape.dims()[dim];
let all_data = Tensor::zeros(shape, dtype, dev)?;
Ok(Self {
all_data,
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
})
}
}
pub fn dim(&self) -> usize {
@ -39,16 +34,34 @@ impl Cache {
self.max_seq_len
}
pub fn all_data(&self) -> &Tensor {
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
pub fn current_data(&self) -> Result<Tensor> {
self.all_data.narrow(self.dim, 0, self.current_seq_len)
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
};
Ok(data)
}
pub fn reset(&mut self) {
self.current_seq_len = 0;
self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
@ -56,8 +69,7 @@ impl Cache {
self.max_seq_len
)
}
self.all_data
.slice_set(src, self.dim, self.current_seq_len)?;
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
}
@ -70,32 +82,66 @@ pub struct KvCache {
}
impl KvCache {
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
dim: D,
shape: S,
dtype: DType,
dev: &Device,
) -> Result<Self> {
let shape = shape.into();
let dim = dim.to_index(&shape, "kv-cache")?;
let k = Cache::new(dim, &shape, dtype, dev)?;
let v = Cache::new(dim, &shape, dtype, dev)?;
Ok(Self { k, v })
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = Cache::new(dim, max_seq_len);
let v = Cache::new(dim, max_seq_len);
Self { k, v }
}
pub fn k(&self) -> Result<Tensor> {
pub fn k_cache(&self) -> &Cache {
&self.k
}
pub fn v_cache(&self) -> &Cache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut Cache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut Cache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
pub fn v(&self) -> Result<Tensor> {
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
let k = self.k.current_data()?;
let v = self.v.current_data()?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
None => {
let mut shape = k.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, k.dtype(), k.device())?
}
Some(k) => k,
};
let v = match out_v {
None => {
let mut shape = v.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, v.dtype(), v.device())?
}
Some(v) => v,
};
Ok((k, v))
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}
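A minimal usage sketch of the reworked cache, assuming the candle_nn::kv_cache module path and a (batch, heads, seq, head_dim) layout where dimension 2 is the sequence axis; the shapes and sizes are illustrative only.

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    // The backing buffers are only allocated on the first append, so the dtype,
    // device and batch size no longer need to be known when the cache is built.
    let mut cache = KvCache::new(2, 512);
    for _step in 0..3 {
        let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        assert_eq!(k_all.dim(2)?, cache.current_seq_len());
        let _ = v_all;
    }
    // Resetting drops the buffers; a reset cache can be cloned without sharing state.
    cache.reset();
    Ok(())
}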

View File

@ -11,8 +11,8 @@
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
//! let b = Tensor::new(0f32, &Cpu)?;
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
@ -107,6 +107,11 @@ impl LayerNorm {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
if x.is_contiguous() && self.remove_mean {
if let Some(bias) = self.bias.as_ref() {
return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
}
}
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,

View File

@ -1,4 +1,4 @@
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
let xs = xs.chunk(2, D::Minus1)?;
&xs[0].silu()? * &xs[1]
}
@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(candle::D::Minus1)?;
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(candle::D::Minus1)?;
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
if hidden_size_xs != hidden_size_alpha {
candle::bail!(
@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
#[derive(Debug, Clone)]
struct LayerNorm {
eps: f32,
}
impl candle::CustomOp3 for LayerNorm {
fn name(&self) -> &'static str {
"layer-norm"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
let eps = self.eps;
fn inner<
T: candle::WithDType
+ num_traits::Float
+ num_traits::AsPrimitive<f32>
+ num_traits::FromPrimitive,
>(
src: &[T],
layout: &Layout,
alpha: &[T],
alpha_layout: &Layout,
beta: &[T],
beta_layout: &Layout,
eps: f32,
) -> Result<(CpuStorage, Shape)> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => &alpha[o1..o2],
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => &beta[o1..o2],
};
let el_count = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let mut dst = vec![T::zero(); el_count];
src.par_chunks(dim_m1)
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut sum = 0f32;
let mut sum2 = 0f32;
for v in src {
let v = v.as_();
sum += v;
sum2 += v * v;
}
let mean = sum / dim_m1 as f32;
let var = sum2 / dim_m1 as f32 - mean * mean;
let inv_std = (var + eps).sqrt().recip();
for ((d, s), (alpha, beta)) in
dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
{
let alpha = alpha.as_();
let beta = beta.as_();
let d_ = (s.as_() - mean) * inv_std * alpha + beta;
*d = T::from_f32(d_).unwrap_or_else(T::nan);
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, Shape::from_dims(dims)))
}
use CpuStorage as C;
match (s1, s2, s3) {
(C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
}
(C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
(C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
_ => candle::bail!("unsupported dtype for layer-norm {:?}", s1.dtype()),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
struct S {
eps: f32,
}
impl Map3 for S {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
layout: &Layout,
alpha: &CudaSlice<T>,
alpha_layout: &Layout,
beta: &CudaSlice<T>,
beta_layout: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => alpha.slice(o1..o2),
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => beta.slice(o1..o2),
};
let el = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
use candle::backend::BackendStorage;
let dev = s1.device();
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
s1: &candle::MetalStorage,
l1: &Layout,
s2: &candle::MetalStorage,
l2: &Layout,
s3: &candle::MetalStorage,
l3: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = s1.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
(dt1, dt2, dt3) => {
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
}
};
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
candle::bail!("Non contiguous layernorm is not implemented");
}
let last_dim = l1.dims()[l1.shape().rank() - 1];
let elem_count = l1.shape().elem_count();
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
candle_metal_kernels::call_layer_norm(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
self.eps,
s1.buffer(),
l1.start_offset() * s1.dtype().size_in_bytes(),
s2.buffer(),
l2.start_offset() * s2.dtype().size_in_bytes(),
s3.buffer(),
l3.start_offset() * s3.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
Ok((newstorage, l1.shape().clone()))
}
}
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(alpha)?
.broadcast_add(beta)
}
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
let hidden_size_beta = beta.dims1()?;
if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
candle::bail!(
"shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
xs.shape(),
alpha.shape(),
beta.shape()
)
}
xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
@ -678,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
n => candle::bail!("replication-pad with a size of {n} is not supported"),
}
}
#[derive(Clone, Debug)]
pub struct Identity;
impl Identity {
pub fn new() -> Identity {
Self
}
}
impl Default for Identity {
fn default() -> Self {
Self
}
}
impl Module for Identity {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
Ok(xs.clone())
}
}
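A quick sketch comparing the fused op against the slow reference path, mirroring the unit test added further below; the shapes are arbitrary.

use candle::{DType, Device, Result, Tensor};

fn check() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::randn(0f32, 1f32, (2, 5, 16), &dev)?;
    let alpha = Tensor::ones(16, DType::F32, &dev)?;
    let beta = Tensor::zeros(16, DType::F32, &dev)?;
    // The fused path dispatches to the new cpu/cuda/metal layernorm kernels.
    let fast = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    let slow = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
    let diff = (fast - slow)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}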

View File

@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> {
let device = &Device::Cpu;
let w = Tensor::new(&[3f32], device)?;
let b = Tensor::new(&[0.5f32], device)?;
let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8);
let ln3 = LayerNorm::new(
Tensor::cat(&[&w, &w, &w], 0)?,
Tensor::cat(&[&b, &b, &b], 0)?,
1e-8,
);
let ln = LayerNorm::new(w, b, 1e-8);
let two = Tensor::new(&[[[2f32]]], device)?;
@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> {
assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);
let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
let res = ln.forward(&inp)?;
let res = ln2.forward(&inp)?;
assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);
let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
let res = ln.forward(&inp)?;
let res = ln3.forward(&inp)?;
assert_eq!(
test_utils::to_vec3_round(&res, 4)?,
[[
@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> {
);
let mean = (res.sum_keepdim(2)? / 3.0)?;
// The average value should be `b`.
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
assert_eq!(
test_utils::to_vec3_round(&mean, 4)?,
[[[0.5], [0.5], [0.5]]]
);
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
// The standard deviation should be sqrt(`w`).
assert_eq!(

View File

@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.5.1"
version = "0.6.0"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.5.1" }
candle-nn = { path = "../candle-nn", version = "0.5.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
candle-nn = { path = "../candle-nn", version = "0.6.0" }
prost = "0.12.1"
[build-dependencies]

View File

@ -1,6 +1,6 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
DataType::Bool => Some(DType::U8),
_ => None,
}
}
@ -56,6 +57,15 @@ impl Attr for str {
}
}
impl Attr for GraphProto {
const TYPE: AttributeType = AttributeType::Graph;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
attr.g
.as_ref()
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
}
}
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@ -214,13 +224,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
simple_eval_(graph, &mut inputs)
}
fn simple_eval_(
graph: &onnx::GraphProto,
values: &mut HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
@ -877,6 +893,16 @@ pub fn simple_eval(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;
values.insert(node.output[0].clone(), output);
}
"Floor" => {
let input = get(&node.input[0])?;
let output = input.floor()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
@ -948,6 +974,165 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
"If" => {
// protobuf encodes boolean false as 0 and true as 1
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
let attr_name = if cond != 0 {
"then_branch"
} else {
"else_branch"
};
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
if sub_graph.output.len() != node.output.len() {
bail!(
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
node.name,
sub_graph.output.len(),
node.output.len()
);
}
let branch_out = simple_eval_(sub_graph, values)?;
for (i, out) in node.output.iter().enumerate() {
values.insert(
out.clone(),
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
);
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
"Pad" => {
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
let data = get(&node.input[0])?;
let pads = get(&node.input[1])?;
if node.input.len() > 2 {
bail!(
"unsupported number of inputs {} for Pad node {:?}, expected 2",
node.input.len(),
node.name
);
}
if pads.rank() != 1 {
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
}
if pads.dim(0).unwrap() != 2 * data.rank() {
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
}
let pads = pads.to_vec1::<i64>()?;
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
match mode {
"reflect" => {
let mut out = data.clone();
for (i, &dim) in data.dims().iter().enumerate().rev() {
if pads_pre[i] == 0 && pads_post[i] == 0 {
continue;
}
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
}
let idx = if dim > 1 {
let cycle_len = dim * 2 - 1;
let skip = (pads_pre[i] as usize) % cycle_len;
let idx = zigzag(0, (dim - 1) as i64)
.skip(skip)
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
Tensor::from_iter(idx, out.device())?
} else {
Tensor::full(0i64, (dim,), out.device())?
};
out = out.index_select(&idx, i)?;
}
values.insert(node.output[0].clone(), out);
}
_ => bail!(
"unsupported 'mode' value {mode:?} for Pad node {:?}",
node.name
),
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
"Slice" => {
let data = get(&node.input[0])?;
let starts = get(&node.input[1])?;
let ends = get(&node.input[2])?;
let default_axes;
let default_steps;
let axes: &Tensor;
let steps: &Tensor;
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
// they are set to [1, ..., 1] of length len(starts)
match node.input.len() {
3 => {
let len = starts.dims()[0];
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
axes = default_axes.as_ref().unwrap();
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
4 => {
let len = starts.dims()[0];
axes = get(&node.input[3])?;
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
5 => {
steps = get(&node.input[4])?;
axes = get(&node.input[3])?;
}
_ => bail!(
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
node.input.len(),
node
),
}
let mut out = data.clone();
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
// All negative elements of axes are made non-negative by
// adding r to them, where r = rank(input).
let axis = if axis < 0 {
axis + data.rank() as i64
} else {
axis
} as usize;
let data_dim = data.dims()[axis] as i64;
let mut s = starts.get(i)?.to_scalar::<i64>()?;
let mut e = ends.get(i)?.to_scalar::<i64>()?;
// All negative values in starts[i] and ends[i] have
// dims[axes[i]] added to them, where dims are the
// dimensions of input.
if s < 0 {
s += data_dim;
}
if e < 0 {
e += data_dim;
}
let p = steps.get(i)?.to_scalar::<i64>()?;
// starts[i] is clamped into the range [0, dims[axes[i]]]
// for positive stepping and [0, dims[axes[i]]-1] for
// negative stepping.
// for positive stepping ends[axes[i]] is clamped to
// [0, dims[axes[i]]], while for negative stepping it is
// clamped to [-1, dims[axes[i]]-1].
if p >= 0 {
s = s.clamp(0, data_dim);
e = e.clamp(0, data_dim);
} else {
s = s.clamp(0, data_dim - 1);
e = e.clamp(-1, data_dim - 1);
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
@ -1017,6 +1202,102 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"ArgMin" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
axis_i64,
-rank_i64,
rank_i64 - 1
)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
.copied()
.unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmin_keepdim(axis)?
} else {
input.argmin(axis)?
}
.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"ArgMax" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
axis_i64,
-rank_i64,
rank_i64 - 1
)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
.copied()
.unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmax_keepdim(axis)?
} else {
input.argmax(axis)?
}
.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
let input = get(&node.input[0])?;
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
bail!(
"unsupported dtype {}, only float types are allowed for LeakyRelu",
dt.as_str()
)
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
"Gemm" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let c = get(&node.input[2])?;
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
let a = if trans_a == 0 { a.clone() } else { a.t()? };
let b = if trans_b == 0 { b.clone() } else { b.t()? };
let output = a
.broadcast_mul(&alpha)?
.broadcast_matmul(&b)?
.broadcast_add(&c.broadcast_mul(&beta)?)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
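For reference, the Gemm handler above follows the ONNX definition, with the scalar alpha and beta attributes materialized as tensors and applied via broadcasting:

$$Y = \alpha \, A' B' + \beta \, C, \qquad A' = \begin{cases} A^\top & \text{if transA} = 1 \\ A & \text{otherwise} \end{cases}, \qquad B' = \begin{cases} B^\top & \text{if transB} = 1 \\ B & \text{otherwise} \end{cases}$$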

File diff suppressed because it is too large

View File

@ -262,6 +262,20 @@ impl ClipEncoder {
}
Ok(xs)
}
// required by LLaVA
pub fn output_hidden_states(
&self,
xs: &Tensor,
causal_attention_mask: Option<&Tensor>,
) -> Result<Vec<Tensor>> {
let mut xs = xs.clone();
let mut hidden_states = Vec::new();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
hidden_states.push(xs.clone());
}
Ok(hidden_states)
}
}
/// A CLIP transformer based model.

View File

@ -46,6 +46,19 @@ impl ClipVisionConfig {
patch_size: 32,
}
}
pub fn clip_vit_large_patch14_336() -> Self {
Self {
embed_dim: 1024,
activation: Activation::QuickGelu,
intermediate_size: 4096,
num_hidden_layers: 24,
num_attention_heads: 16,
projection_dim: 768,
num_channels: 3,
image_size: 336,
patch_size: 14,
}
}
}
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
@ -130,6 +143,17 @@ impl ClipVisionTransformer {
pre_layer_norm,
})
}
// required by LLaVA
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
}
}
impl Module for ClipVisionTransformer {

View File

@ -0,0 +1,553 @@
use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
};
use crate::models::dinov2::DinoVisionTransformer;
pub struct DepthAnythingV2Config {
out_channel_sizes: [usize; 4],
in_channel_size: usize, // embed_dim in the Dino model
num_features: usize,
use_batch_norm: bool,
use_class_token: bool,
layer_ids_vits: Vec<usize>,
input_image_size: usize,
target_patch_size: usize,
}
impl DepthAnythingV2Config {
#[allow(clippy::too_many_arguments)]
pub fn new(
out_channel_sizes: [usize; 4],
in_channel_size: usize,
num_features: usize,
use_batch_norm: bool,
use_class_token: bool,
layer_ids_vits: Vec<usize>,
input_image_size: usize,
target_patch_size: usize,
) -> Self {
Self {
out_channel_sizes,
in_channel_size,
num_features,
use_batch_norm,
use_class_token,
layer_ids_vits,
input_image_size,
target_patch_size,
}
}
pub fn vit_small() -> Self {
Self {
out_channel_sizes: [48, 96, 192, 384],
in_channel_size: 384,
num_features: 64,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![2, 5, 8, 11],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_base() -> Self {
Self {
out_channel_sizes: [96, 192, 384, 768],
in_channel_size: 768,
num_features: 128,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![2, 5, 8, 11],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_large() -> Self {
Self {
out_channel_sizes: [256, 512, 1024, 1024],
in_channel_size: 1024,
num_features: 256,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![4, 11, 17, 23],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_giant() -> Self {
Self {
out_channel_sizes: [1536, 1536, 1536, 1536],
in_channel_size: 1536,
num_features: 384,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![9, 19, 29, 39],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
}
pub struct ResidualConvUnit {
activation: Activation,
conv1: Conv2d,
conv2: Conv2d,
batch_norm1: Option<BatchNorm>,
batch_norm2: Option<BatchNorm>,
}
impl ResidualConvUnit {
pub fn new(
conf: &DepthAnythingV2Config,
activation: Activation,
vb: VarBuilder,
) -> Result<Self> {
const KERNEL_SIZE: usize = 3;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let conv1 = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("conv1"),
)?;
let conv2 = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("conv2"),
)?;
let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
true => {
let batch_norm_cfg = BatchNormConfig {
eps: 1e-05,
remove_mean: false,
affine: true,
momentum: 0.1,
};
(
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
)
}
false => (None, None),
};
Ok(Self {
activation,
conv1,
conv2,
batch_norm1,
batch_norm2,
})
}
}
impl Module for ResidualConvUnit {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let out = self.activation.forward(xs)?;
let out = self.conv1.forward(&out)?;
let out = if let Some(batch_norm1) = &self.batch_norm1 {
batch_norm1.forward_train(&out)?
} else {
out
};
let out = self.activation.forward(&out)?;
let out = self.conv2.forward(&out)?;
let out = if let Some(batch_norm2) = &self.batch_norm2 {
batch_norm2.forward_train(&out)?
} else {
out
};
out + xs
}
}
pub struct FeatureFusionBlock {
res_conv_unit1: ResidualConvUnit,
res_conv_unit2: ResidualConvUnit,
output_conv: Conv2d,
target_patch_size: usize,
}
impl FeatureFusionBlock {
pub fn new(
conf: &DepthAnythingV2Config,
target_patch_size: usize,
activation: Activation,
vb: VarBuilder,
) -> Result<Self> {
const KERNEL_SIZE: usize = 1;
let conv_cfg = Conv2dConfig {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
};
let output_conv = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("out_conv"),
)?;
let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
Ok(Self {
res_conv_unit1,
res_conv_unit2,
output_conv,
target_patch_size,
})
}
}
impl Module for FeatureFusionBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let out = self.res_conv_unit2.forward(xs)?;
let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
self.output_conv.forward(&out)
}
}
pub struct Scratch {
layer1_rn: Conv2d,
layer2_rn: Conv2d,
layer3_rn: Conv2d,
layer4_rn: Conv2d,
refine_net1: FeatureFusionBlock,
refine_net2: FeatureFusionBlock,
refine_net3: FeatureFusionBlock,
refine_net4: FeatureFusionBlock,
output_conv1: Conv2d,
output_conv2: Sequential,
}
impl Scratch {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
const KERNEL_SIZE: usize = 3;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let layer1_rn = conv2d_no_bias(
conf.out_channel_sizes[0],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer1_rn"),
)?;
let layer2_rn = conv2d_no_bias(
conf.out_channel_sizes[1],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer2_rn"),
)?;
let layer3_rn = conv2d_no_bias(
conf.out_channel_sizes[2],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer3_rn"),
)?;
let layer4_rn = conv2d_no_bias(
conf.out_channel_sizes[3],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer4_rn"),
)?;
let refine_net1 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 8,
Activation::Relu,
vb.pp("refinenet1"),
)?;
let refine_net2 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 4,
Activation::Relu,
vb.pp("refinenet2"),
)?;
let refine_net3 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 2,
Activation::Relu,
vb.pp("refinenet3"),
)?;
let refine_net4 = FeatureFusionBlock::new(
conf,
conf.target_patch_size,
Activation::Relu,
vb.pp("refinenet4"),
)?;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let output_conv1 = conv2d(
conf.num_features,
conf.num_features / 2,
KERNEL_SIZE,
conv_cfg,
vb.pp("output_conv1"),
)?;
let output_conv2 = seq();
const HEAD_FEATURES_2: usize = 32;
const OUT_CHANNELS_2: usize = 1;
const KERNEL_SIZE_2: usize = 1;
let output_conv2 = output_conv2.add(conv2d(
conf.num_features / 2,
HEAD_FEATURES_2,
KERNEL_SIZE,
conv_cfg,
vb.pp("output_conv2").pp("0"),
)?);
let output_conv2 = output_conv2
.add(Activation::Relu)
.add(conv2d(
HEAD_FEATURES_2,
OUT_CHANNELS_2,
KERNEL_SIZE_2,
conv_cfg,
vb.pp("output_conv2").pp("2"),
)?)
.add(Activation::Relu);
Ok(Self {
layer1_rn,
layer2_rn,
layer3_rn,
layer4_rn,
refine_net1,
refine_net2,
refine_net3,
refine_net4,
output_conv1,
output_conv2,
})
}
}
const NUM_CHANNELS: usize = 4;
pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
}
impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
conf.in_channel_size,
*out_channel_size,
1,
Default::default(),
vb.pp("projects").pp(conv_index.to_string()),
)?);
}
let resize_layers: Vec<Box<dyn Module>> = vec![
Box::new(conv_transpose2d(
conf.out_channel_sizes[0],
conf.out_channel_sizes[0],
4,
ConvTranspose2dConfig {
padding: 0,
stride: 4,
dilation: 1,
output_padding: 0,
},
vb.pp("resize_layers").pp("0"),
)?),
Box::new(conv_transpose2d(
conf.out_channel_sizes[1],
conf.out_channel_sizes[1],
2,
ConvTranspose2dConfig {
padding: 0,
stride: 2,
dilation: 1,
output_padding: 0,
},
vb.pp("resize_layers").pp("1"),
)?),
Box::new(Identity::new()),
Box::new(conv2d(
conf.out_channel_sizes[3],
conf.out_channel_sizes[3],
3,
Conv2dConfig {
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
},
vb.pp("resize_layers").pp("3"),
)?),
];
let readout_projections = if conf.use_class_token {
let rop = Vec::with_capacity(NUM_CHANNELS);
for rop_index in 0..NUM_CHANNELS {
seq()
.add(linear(
2 * conf.in_channel_size,
conf.in_channel_size,
vb.pp("readout_projects").pp(rop_index.to_string()),
)?)
.add(Activation::Gelu);
}
rop
} else {
vec![]
};
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
})
}
}
impl Module for DPTHead<'_> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
let to_cat = [x, readout];
let cat = Tensor::cat(&to_cat, Minus1)?;
self.readout_projections[i].forward(&cat)?
} else {
xs.get(i)?
};
let x_dims = x.dims();
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;
let x = self.resize_layers[i].forward(&x)?;
out.push(x);
}
let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
let res3_out = self
.scratch
.refine_net3
.res_conv_unit1
.forward(&layer_3_rn)?;
let res3_out = path4.add(&res3_out)?;
let path3 = self.scratch.refine_net3.forward(&res3_out)?;
let res2_out = self
.scratch
.refine_net2
.res_conv_unit1
.forward(&layer_2_rn)?;
let res2_out = path3.add(&res2_out)?;
let path2 = self.scratch.refine_net2.forward(&res2_out)?;
let res1_out = self
.scratch
.refine_net1
.res_conv_unit1
.forward(&layer_1_rn)?;
let res1_out = path2.add(&res1_out)?;
let path1 = self.scratch.refine_net1.forward(&res1_out)?;
let out = self.scratch.output_conv1.forward(&path1)?;
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
self.scratch.output_conv2.forward(&out)
}
}
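// A minimal shape sanity check for the token-to-feature-map step above (illustrative only, not
// part of the model code), assuming the ViT-S defaults: in_channel_size = 384 and
// target_patch_size = 37, i.e. 518 / 14.
fn dpt_head_reshape_check() -> candle::Result<()> {
    use candle::{DType, Device, Tensor};
    let tokens = Tensor::zeros((1, 37 * 37, 384), DType::F32, &Device::Cpu)?;
    let dims = tokens.dims();
    // mirrors the permute + reshape in DPTHead::forward
    let feature_map = tokens
        .permute((0, 2, 1))?
        .reshape((dims[0], dims[2], 37, 37))?;
    assert_eq!(feature_map.dims(), [1, 384, 37, 37]);
    Ok(())
}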
pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
}
impl<'a> DepthAnythingV2<'a> {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
Ok(Self {
pretrained,
depth_head,
conf,
})
}
}
impl<'a> Module for DepthAnythingV2<'a> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,
&self.conf.layer_ids_vits,
false,
false,
true,
)?;
let depth = self.depth_head.forward(&features)?;
depth.relu()
}
}
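// A rough end-to-end sketch (not part of the model code): the DINOv2 backbone, the config and the
// depth-anything VarBuilder are assumed to be loaded elsewhere, and `image` is assumed to be a
// normalized (1, 3, input_image_size, input_image_size) tensor.
fn run_depth_anything(
    dinov2: &DinoVisionTransformer,
    conf: &DepthAnythingV2Config,
    vb: VarBuilder,
    image: &Tensor,
) -> Result<Tensor> {
    let model = DepthAnythingV2::new(dinov2, conf, vb)?;
    // (1, 1, input_image_size, input_image_size), non-negative thanks to the final relu
    model.forward(image)
}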

@@ -258,6 +258,84 @@ impl DinoVisionTransformer {
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
fn get_intermediate_layers_not_chunked(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
) -> Result<Vec<Tensor>> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
let mut output = Vec::new();
for (i, blk) in self.blocks.iter().enumerate() {
xs = blk.forward(&xs)?;
if blocks_to_take.contains(&i) {
output.push(xs.clone());
}
}
if output.len() != blocks_to_take.len() {
candle::bail!(
"only {} / {} blocks found",
output.len(),
blocks_to_take.len()
);
}
Ok(output)
}
pub fn get_intermediate_layers(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
reshape: bool,
return_class_token: bool,
norm: bool,
) -> Result<Tensor> {
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
let outputs = if norm {
outputs
.iter()
.map(|out| self.norm.forward(out))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let class_tokens = outputs
.iter()
.map(|out| out.i((.., 0)))
.collect::<Result<Vec<_>>>()?;
let outputs = outputs
.iter()
.map(|out| out.i((.., 1..)))
.collect::<Result<Vec<_>>>()?;
let outputs = if reshape {
let (b, _c, w, h) = xs.dims4()?;
let patch_size = self.patch_embed.patch_size.0;
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
outputs
.iter()
.map(|out| {
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
.transpose(2, 3)?
.transpose(1, 2)
})
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let outputs = if return_class_token {
outputs
.iter()
.zip(class_tokens.iter())
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
Tensor::stack(&outputs[..], 0)
}
}
impl Module for DinoVisionTransformer {

@@ -388,6 +388,28 @@ pub struct Llama {
}
impl Llama {
// required by LLaVA
pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
self.wte.forward(x)
}
// required by LLaVA
pub fn forward_input_embed(
&self,
input_embed: &Tensor,
index_pos: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (_, seq_len, _) = input_embed.dims3()?;
let mut x = input_embed.clone();
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;

@@ -0,0 +1,267 @@
use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
llama::Config,
};
use serde::{Deserialize, Serialize};
// original config from liuhaotian/llava
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLaVAConfig {
pub architectures: Vec<String>,
pub bos_token_id: usize,
pub eos_token_id: usize,
pub hidden_size: usize,
#[serde(default = "default_image_aspect_ratio")]
pub image_aspect_ratio: String,
pub image_crop_resolution: usize,
pub image_grid_pinpoints: Vec<(u32, u32)>,
pub image_split_resolution: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub mm_hidden_size: usize,
#[serde(default = "default_mm_patch_merge_type")]
pub mm_patch_merge_type: String,
pub mm_projector_type: String,
pub mm_use_im_start_end: bool,
pub mm_vision_select_feature: String,
pub mm_vision_select_layer: isize,
pub mm_vision_tower: Option<String>,
pub model_type: String,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub pad_token_id: usize,
pub rms_norm_eps: f32,
pub rope_theta: f32,
pub tokenizer_model_max_length: Option<usize>,
pub torch_dtype: String,
pub use_cache: bool,
pub vocab_size: usize,
#[serde(default = "default_image_token_index")]
pub image_token_index: isize,
#[serde(default = "default_hf")]
pub hf: bool,
}
fn default_hf() -> bool {
false
}
fn default_image_token_index() -> isize {
-200
}
fn default_mm_patch_merge_type() -> String {
"flat".to_string()
}
fn default_image_aspect_ratio() -> String {
"square".to_string()
}
impl LLaVAConfig {
pub fn to_llama_config(&self) -> Config {
Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads,
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
use_flash_attn: false,
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVATextConfig {
pub architectures: Vec<String>,
#[serde(default = "default_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_max_length")]
pub max_length: usize,
pub max_position_embeddings: usize,
pub model_type: String,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_num_hidden_layers")]
pub num_hidden_layers: usize,
#[serde(default = "default_num_key_value_heads")]
pub num_key_value_heads: usize,
pub pad_token_id: usize,
pub rms_norm_eps: f32,
#[serde(default = "default_rope_theta")]
pub rope_theta: f32,
pub torch_dtype: String,
#[serde(default = "default_use_cache")]
pub use_cache: bool,
pub vocab_size: usize,
}
fn default_num_hidden_layers() -> usize {
32
}
fn default_use_cache() -> bool {
true
}
fn default_hidden_size() -> usize {
4096
}
fn default_intermediate_size() -> usize {
11008
}
fn default_max_length() -> usize {
4096
}
fn default_num_attention_heads() -> usize {
32
}
fn default_num_key_value_heads() -> usize {
32
}
fn default_rope_theta() -> f32 {
10000.0
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAVisionConfig {
pub hidden_size: usize,
pub image_size: usize,
pub intermediate_size: usize,
pub model_type: String,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub patch_size: usize,
pub projection_dim: usize,
pub vocab_size: usize,
}
// config from llava-v1.6-vicuna-7b-hf
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAConfig {
pub architectures: Vec<String>,
pub ignore_index: isize,
pub image_grid_pinpoints: Vec<(u32, u32)>,
pub image_token_index: isize,
pub model_type: String,
pub projector_hidden_act: String,
pub text_config: HFLLaVATextConfig,
pub torch_dtype: String,
pub use_image_newline_parameter: bool,
pub vision_config: HFLLaVAVisionConfig,
pub vision_feature_layer: isize,
pub vision_feature_select_strategy: String,
pub vocab_size: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFGenerationConfig {
pub bos_token_id: usize,
pub eos_token_id: usize,
#[serde(default = "default_max_length")]
pub max_length: usize,
pub pad_token_id: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFPreProcessorConfig {
pub aspect_ratio_setting: String,
pub crop_size: HashMap<String, usize>,
pub do_center_crop: bool,
pub do_convert_rgb: bool,
pub do_normalize: bool,
pub do_rescale: bool,
pub do_resize: bool,
pub image_mean: Vec<f32>,
pub image_std: Vec<f32>,
pub resample: u32,
pub rescale_factor: f32,
pub size: HashMap<String, f32>,
}
impl HFLLaVAConfig {
pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
ClipVisionConfig {
embed_dim: self.vision_config.hidden_size,
activation: Activation::QuickGelu,
intermediate_size: self.vision_config.intermediate_size,
num_hidden_layers: self.vision_config.num_hidden_layers,
num_attention_heads: self.vision_config.num_attention_heads,
projection_dim: self.vision_config.projection_dim,
num_channels: 3,
image_size: self.vision_config.image_size,
patch_size: self.vision_config.patch_size,
}
}
fn map_projector_type(s: &str) -> String {
if s == "gelu" {
"mlp2x_gelu".to_string()
} else {
s.to_string()
}
}
fn map_select_feature(s: &str) -> String {
if s == "default" {
"patch".to_string()
} else {
"cls_patch".to_string()
}
}
pub fn to_llava_config(
&self,
generation_config: &HFGenerationConfig,
preprocessor_config: &HFPreProcessorConfig,
) -> LLaVAConfig {
LLaVAConfig {
hf: true,
architectures: self.architectures.clone(),
bos_token_id: generation_config.bos_token_id,
eos_token_id: generation_config.eos_token_id,
hidden_size: self.text_config.hidden_size,
image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
image_crop_resolution: 224,
image_grid_pinpoints: self.image_grid_pinpoints.clone(),
image_split_resolution: 224,
intermediate_size: self.text_config.intermediate_size,
max_position_embeddings: self.text_config.max_position_embeddings,
mm_hidden_size: 1024,
mm_patch_merge_type: "spatial_unpad".to_string(),
mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
mm_use_im_start_end: false,
mm_vision_select_feature: Self::map_select_feature(
&self.vision_feature_select_strategy,
),
mm_vision_select_layer: self.vision_feature_layer,
mm_vision_tower: None,
model_type: self.model_type.clone(),
num_attention_heads: self.text_config.num_attention_heads,
num_hidden_layers: self.text_config.num_hidden_layers,
num_key_value_heads: self.text_config.num_key_value_heads,
pad_token_id: self.text_config.pad_token_id,
rms_norm_eps: self.text_config.rms_norm_eps,
rope_theta: self.text_config.rope_theta,
tokenizer_model_max_length: Some(4096),
torch_dtype: self.torch_dtype.clone(),
use_cache: self.text_config.use_cache,
vocab_size: self.vocab_size,
image_token_index: self.image_token_index,
}
}
}

@@ -0,0 +1,407 @@
pub mod config;
pub mod utils;
use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;
use candle::{bail, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
use config::LLaVAConfig;
fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
if let Some(match_str) = captures.get(1) {
let match_str = match_str.as_str();
match_str.parse::<usize>().ok()
} else {
None
}
} else {
None
}
}
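// For instance, the projector string used by recent LLaVA checkpoints maps to a depth of 2, i.e.
// one linear layer followed by a single gelu + linear pair in MMProjector::load below
// (illustrative check only).
fn mlp_gelu_match_examples() {
    assert_eq!(mlp_gelu_match("mlp2x_gelu"), Some(2));
    assert_eq!(mlp_gelu_match("linear"), None);
}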
fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
assert_eq!(tensor.dims().len(), 3);
let (original_width, original_height) = *original_size;
let tensor_dims = tensor.dims();
let current_height = tensor_dims[1];
let current_width = tensor_dims[2];
let original_aspect_ratio = (original_width as f32) / (original_height as f32);
let current_aspect_ratio = (current_width as f32) / (current_height as f32);
if original_aspect_ratio > current_aspect_ratio {
let scale_factor = (current_width as f32) / (original_width as f32);
let new_height = (original_height as f32 * scale_factor).floor() as usize;
let padding = (current_height - new_height) / 2;
// slice along the height dimension (dim 1) to drop the vertical padding
tensor.i((.., padding..current_height - padding, ..))
} else {
let scale_factor = (current_height as f32) / (original_height as f32);
let new_width = (original_width as f32 * scale_factor).floor() as usize;
let padding = (current_width - new_width) / 2;
tensor.i((.., .., padding..current_width - padding))
}
}
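// A small worked example of the unpadding math (illustrative only): a 24x24 feature map that came
// from a 640x480 image was padded along the height, so 3 rows are cropped from top and bottom.
fn unpad_image_example() -> Result<()> {
    let padded = Tensor::zeros((5, 24, 24), candle::DType::F32, &Device::Cpu)?;
    // scale = 24 / 640 = 0.0375, new_height = floor(480 * 0.0375) = 18, padding = (24 - 18) / 2 = 3
    let unpadded = unpad_image(&padded, &(640, 480))?;
    assert_eq!(unpadded.dims(), [5, 18, 24]);
    Ok(())
}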
pub struct IdentityMap {}
impl Module for IdentityMap {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
Ok(x.clone())
}
}
pub struct MMProjector {
pub modules: Sequential,
}
impl MMProjector {
pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
if config.mm_projector_type == "linear" {
let vb_prefix = if config.hf {
"multi_modal_projector.linear_1"
} else {
"model.mm_projector.0"
};
let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
let modules = seq().add(linear);
Ok(Self { modules })
} else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
let modules = if config.hf {
let mut modules = seq().add(linear(
config.mm_hidden_size,
config.hidden_size,
vb.pp("multi_modal_projector.linear_1"),
)?);
for i in 1..mlp_depth {
modules = modules.add(Activation::Gelu).add(linear(
config.hidden_size,
config.hidden_size,
vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
)?);
}
modules
} else {
let mut modules = seq().add(linear(
config.mm_hidden_size,
config.hidden_size,
vb.pp("model.mm_projector.0"),
)?);
for i in 1..mlp_depth {
modules = modules.add(Activation::Gelu).add(linear(
config.hidden_size,
config.hidden_size,
vb.pp(format!("model.mm_projector.{}", i * 2)),
)?);
}
modules
};
Ok(Self { modules })
} else if config.mm_projector_type == "identity" {
Ok(Self {
modules: seq().add(IdentityMap {}),
})
} else {
bail!(
"Unsupported MM projector type: {}",
config.mm_projector_type
)
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.modules.forward(x)
}
}
pub struct ClipVisionTower {
model: ClipVisionTransformer,
select_layer: isize,
select_feature_method: String,
pub config: ClipVisionConfig,
}
impl ClipVisionTower {
pub fn new(
vb: VarBuilder,
select_layer: isize,
select_feature_method: &str,
config: &Option<ClipVisionConfig>,
) -> Result<Self> {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
config.clone().unwrap()
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
_ => bail!("Unsupported select layer: {}", select_layer),
};
let model = ClipVisionTransformer::new(vb, &config)?;
Ok(Self {
model,
select_layer,
select_feature_method: select_feature_method.to_string(),
config,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let result = self.model.output_hidden_states(x)?;
let index = result.len() as isize + self.select_layer;
let result = result[index as usize].clone();
if self.select_feature_method == "cls_patch" {
Ok(result)
} else {
result.i((.., 1..))
}
}
pub fn num_patches_per_side(&self) -> usize {
self.config.image_size / self.config.patch_size
}
}
pub struct LLaVA {
pub clip_vision_tower: ClipVisionTower,
pub image_newline: Tensor,
pub mm_projector: MMProjector,
pub llama: Llama,
config: LLaVAConfig,
device: Device,
}
impl LLaVA {
pub fn load(
vb: VarBuilder,
config: &LLaVAConfig,
clip_vision_config: Option<ClipVisionConfig>,
) -> Result<Self> {
let device = vb.device().clone();
let llama_config = config.to_llama_config();
let mm_projector = MMProjector::load(&vb, config)?;
let (clip_vision_tower, image_newline, llama) = if config.hf {
(
ClipVisionTower::new(
vb.pp("vision_tower.vision_model"),
config.mm_vision_select_layer,
&config.mm_vision_select_feature,
&clip_vision_config,
)?,
vb.get(&[config.hidden_size], "image_newline")?
.to_device(&device)?,
Llama::load(vb.pp("language_model"), &llama_config)?,
)
} else {
(
ClipVisionTower::new(
vb.pp("model.vision_tower.vision_tower.vision_model"),
config.mm_vision_select_layer,
&config.mm_vision_select_feature,
&clip_vision_config,
)?,
vb.get(&[config.hidden_size], "model.image_newline")?
.to_device(&device)?,
Llama::load(vb, &llama_config)?,
)
};
Ok(Self {
clip_vision_tower,
image_newline,
mm_projector,
llama,
config: (*config).clone(),
device,
})
}
pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
let image_features = self.clip_vision_tower.forward(x)?;
let image_features = self.mm_projector.forward(&image_features)?;
Ok(image_features)
}
// currently only supports a single image passed as a 4-dim tensor
pub fn prepare_inputs_labels_for_multimodal(
&self,
input_ids: &Tensor,
images: &[Tensor],
image_sizes: &[(u32, u32)],
) -> Result<Tensor> {
// TODO: handle multiple images and the image-newline token
// 576 patch tokens per image: 336 (input size) / 14 (patch size) = 24; 24 * 24 + 1 (class token) = 577; 577 - 1 = 576
let concat_images = Tensor::cat(images, 0)?;
let image_features_together = self.encode_images(&concat_images)?;
let split_sizes = images
.iter()
.map(|x| x.shape().dims()[0])
.collect::<Vec<usize>>();
// can be replaced by split
let mut index_pos = 0;
let mut image_features = Vec::new();
for split_size in split_sizes.iter() {
image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
index_pos += *split_size;
}
let mm_patch_merge_type = &self.config.mm_patch_merge_type;
let image_aspect_ratio = &self.config.image_aspect_ratio;
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
.map(|x| x.flatten(0, 1).unwrap())
.collect::<Vec<Tensor>>()
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
let base_image_feature = image_feature.get(0).unwrap();
let patch_image_feature = image_feature.i(1..).unwrap();
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
let image_size = image_sizes[image_idx];
let new_image_feature = if image_aspect_ratio == "anyres" {
let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
image_size,
&self.config.image_grid_pinpoints,
self.clip_vision_tower.config.image_size as u32,
);
patch_image_feature.reshape((
num_patch_height as usize,
num_patch_width as usize,
height,
width,
(),
))?
} else {
todo!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature
.permute((4, 0, 2, 1, 3))?
.flatten(1, 2)?
.flatten(2, 3)?;
let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
let new_image_feature_dims = new_image_feature.dims();
let image_new_line = self
.image_newline
.reshape((self.config.hidden_size, 1, 1))?
.broadcast_as((
new_image_feature_dims[0],
new_image_feature_dims[1],
1,
))?;
let new_image_feature =
Tensor::cat(&[new_image_feature, image_new_line], 2)?;
new_image_feature.flatten(1, 2)?.transpose(0, 1)?
} else {
new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
let new_image_feature = image_feature.get(0).unwrap();
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
&[
new_image_feature,
self.image_newline.clone().unsqueeze(0).unwrap(),
],
0,
)
.unwrap()
} else {
new_image_feature
}
};
new_image_features.push(new_image_feature);
}
new_image_features
} else {
bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
};
// can easily be replaced by nonzero if it is implemented in candle
let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
let mut image_indices = {
let mut image_indices = vec![0_i64];
image_indices.extend(
input_ids_vec
.iter()
.enumerate()
.filter_map(|(i, x)| {
if *x == self.config.image_token_index as i64 {
Some(i as i64)
} else {
None
}
})
.collect::<Vec<i64>>(),
);
image_indices
};
if image_indices.len() == 1 {
// no image token found, only the initial [0]
return self.llama.embed(input_ids);
}
let input_ids_noim = input_ids_vec
.iter()
.filter_map(|x| {
if *x != self.config.image_token_index as i64 {
Some(*x)
} else {
None
}
})
.collect::<Vec<i64>>();
let input_ids_noim_len = input_ids_noim.len();
image_indices.push((input_ids_noim_len) as i64);
let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
// can be replaced by split if it is implemented in candle
let input_embed_no_ims = {
let mut input_embeds = Vec::new();
for i in 0..image_indices.len() - 1 {
let start = (image_indices[i]) as usize;
let end = image_indices[i + 1] as usize;
input_embeds.push(cur_input_embeds.i((start..end, ..))?)
}
input_embeds
};
let mut cur_new_input_embeds = Vec::new();
for (i, image_feature) in image_features.iter().enumerate() {
cur_new_input_embeds.push(input_embed_no_ims[i].clone());
cur_new_input_embeds.push(image_feature.clone());
}
cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
// truncate to the tokenizer's max length if configured
let new_input_embeds =
if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
if new_input_embeds_length > tokenizer_model_max_length {
new_input_embeds.i((..tokenizer_model_max_length, ..))?
} else {
new_input_embeds
}
} else {
new_input_embeds
};
new_input_embeds.unsqueeze(0)
}
pub fn forward(
&self,
input_embeds: &Tensor,
position_id: usize,
cache: &mut Cache,
) -> Result<Tensor> {
self.llama
.forward_input_embed(input_embeds, position_id, cache)
}
}
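// A rough single-image inference sketch (not part of the model code). Assumptions: `input_ids` is
// a (1, seq_len) i64 prompt containing the image_token_index placeholder exactly once, `image` is
// a preprocessed (1, 3, 336, 336) tensor for the default CLIP tower, and f32 is an arbitrary
// choice of cache dtype for the sketch.
fn llava_first_step(
    llava: &LLaVA,
    config: &LLaVAConfig,
    input_ids: &Tensor,
    image: &Tensor,
    image_size: (u32, u32),
    device: &Device,
) -> Result<Tensor> {
    let input_embeds =
        llava.prepare_inputs_labels_for_multimodal(input_ids, &[image.clone()], &[image_size])?;
    let mut cache = Cache::new(true, candle::DType::F32, &config.to_llama_config(), device)?;
    // logits for the last position; a sampling loop would feed the embeddings of new tokens back in
    llava.forward(&input_embeds, 0, &mut cache)
}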

@@ -0,0 +1,41 @@
pub fn get_anyres_image_grid_shape(
image_size: (u32, u32),
grid_pinpoints: &[(u32, u32)],
patch_size: u32,
) -> (u32, u32) {
let (width, height) = select_best_resolution(image_size, grid_pinpoints);
(width / patch_size, height / patch_size)
}
pub fn select_best_resolution(
original_size: (u32, u32),
possible_resolutions: &[(u32, u32)],
) -> (u32, u32) {
let (original_width, original_height) = original_size;
let mut best_fit = (0, 0);
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let mut max_effective_resolution = 0_u32;
let mut min_wasted_resolution = u32::MAX;
for (width, height) in possible_resolutions {
let width_f = *width as f32;
let height_f = *height as f32;
let scale = (width_f / original_width_f).min(height_f / original_height_f);
let (downscaled_width, downscaled_height) = (
(original_width_f * scale) as u32,
(original_height_f * scale) as u32,
);
let effective_resolution =
std::cmp::min((*width) * (*height), downscaled_width * downscaled_height);
let wasted_resolution = (*width) * (*height) - effective_resolution;
if effective_resolution > max_effective_resolution
|| (effective_resolution == max_effective_resolution
&& wasted_resolution < min_wasted_resolution)
{
best_fit = (*width, *height);
max_effective_resolution = effective_resolution;
min_wasted_resolution = wasted_resolution;
}
}
best_fit
}
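// A small worked check of the selection logic (illustrative only), using llava-v1.6 style grid
// pinpoints and a 336 base patch size as an assumed example.
fn select_best_resolution_example() {
    let pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)];
    // an 800x600 image keeps the most pixels when scaled into the 672x672 grid (672 * 504 effective)
    assert_eq!(select_best_resolution((800, 600), &pinpoints), (672, 672));
    // which corresponds to a 2x2 grid of 336x336 patches
    assert_eq!(get_anyres_image_grid_shape((800, 600), &pinpoints, 336), (2, 2));
}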

@@ -6,6 +6,7 @@ pub mod chatglm;
pub mod clip;
pub mod convmixer;
pub mod convnext;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
@@ -17,6 +18,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;

@@ -3,6 +3,7 @@ use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub phi_config: PhiConfig,
pub vision_config: VisionConfig,

@@ -56,24 +56,20 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
sin: emb.sin()?,
cos: emb.cos()?,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}
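// For reference (not part of the diff): candle_nn::rotary_emb::rope applies the same rotate-half
// rotation that the removed lines computed by hand, but takes cos/sin with half the rotary width,
// which is why the cat of [freqs, freqs] is no longer needed. A small equivalence sketch with
// assumed shapes xs = (b, h, t, d) and cos/sin = (t, d / 2):
fn rope_equivalence_check() -> candle::Result<()> {
    use candle::{DType, Device, Tensor, D};
    let dev = Device::Cpu;
    let xs = Tensor::rand(0f32, 1f32, (1, 2, 3, 8), &dev)?;
    let t = Tensor::arange(0u32, 3u32, &dev)?.to_dtype(DType::F32)?.reshape((3, 1))?;
    let inv_freq = Tensor::new(&[1f32, 0.5, 0.25, 0.125], &dev)?.reshape((1, 4))?;
    let freqs = t.matmul(&inv_freq)?;
    let (cos, sin) = (freqs.cos()?, freqs.sin()?);
    let fused = candle_nn::rotary_emb::rope(&xs, &cos, &sin)?;
    // the manual rotate-half path that this change removes
    let full_cos = Tensor::cat(&[&cos, &cos], D::Minus1)?;
    let full_sin = Tensor::cat(&[&sin, &sin], D::Minus1)?;
    let xs12 = xs.chunk(2, D::Minus1)?;
    let rotate_half = Tensor::cat(&[&xs12[1].neg()?, &xs12[0]], D::Minus1)?;
    let manual = (xs.broadcast_mul(&full_cos)? + rotate_half.broadcast_mul(&full_sin)?)?;
    let max_diff = (fused - manual)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
    assert!(max_diff < 1e-5);
    Ok(())
}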

@@ -146,7 +146,7 @@ impl LayerWeights {
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?
att.matmul(&v)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attn_output.forward(&y)?;
@@ -203,7 +203,6 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
batch_size: usize,
use_flash_attn: bool,
ct: gguf_file::Content,
reader: &mut R,
@@ -252,12 +251,7 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let kv_cache = KvCache::new(
2,
(batch_size, head_count_kv, max_seq_len, head_dim),
DType::F32,
device,
)?;
let kv_cache = KvCache::new(2, max_seq_len);
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,

@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let base_model = Model::new(cfg, vb)?;
let base_model = Model::new(cfg, vb.clone())?;
let lm_head = if vb.contains_tensor("lm_head") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
};
Ok(Self {
base_model,
lm_head,

@@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
let layers = convs
.iter()
.enumerate()
.map(|(_, &(in_c, out_c, name))| {
.map(|&(in_c, out_c, name)| {
candle_nn::conv2d(
in_c,
out_c,