mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
44 Commits
0.5.1
...
metal-gemm
Author | SHA1 | Date | |
---|---|---|---|
9105aa4390 | |||
2a2a349fd4 | |||
c87dd386a9 | |||
f4b1597b5d | |||
ea578478d4 | |||
a226a9736b | |||
25960676ca | |||
9cd54aa5d4 | |||
eec11ce2ce | |||
9182f9f5c2 | |||
ecff05d72b | |||
7f1ba8038c | |||
74e9e41911 | |||
e27aac0a06 | |||
a3dd87f15e | |||
242e006bbb | |||
6baa1d486b | |||
36cf54525d | |||
2b10aaa05d | |||
9f804af29d | |||
54ff971e35 | |||
b9fac7ec00 | |||
f65e90e7ef | |||
d39462856b | |||
cb180eb23a | |||
9182c828e6 | |||
3f13ad3d79 | |||
cd4d941ed1 | |||
03344d3c19 | |||
1ec3b2cc18 | |||
f7773d498a | |||
7abc3b8cd7 | |||
46012ed31f | |||
f3fade3b03 | |||
ea260aeffd | |||
0814dfd148 | |||
3ceca9901a | |||
1df2bddccf | |||
6f0b807ffd | |||
d54e02d73d | |||
45e235a747 | |||
31cf64147b | |||
77ea479a18 | |||
72e7ca529a |
15
.github/workflows/trufflehog.yml
vendored
Normal file
15
.github/workflows/trufflehog.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -9,6 +9,10 @@ target/
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# editor config
|
||||
.helix
|
||||
.vscode
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
|
20
Cargo.toml
20
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.5.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.6.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
|
@ -236,7 +236,7 @@ If you have an addition to this list, please submit a pull request.
|
||||
- MetaVoice-1B, text-to-speech model.
|
||||
- Computer Vision Models.
|
||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
|
||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
|
||||
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4.
|
||||
- yolo-v3, yolo-v8.
|
||||
- Segment-Anything Model (SAM).
|
||||
- SegFormer.
|
||||
|
@ -106,8 +106,8 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
[[bench]]
|
||||
name = "bench_main"
|
||||
harness = false
|
||||
|
||||
[[example]]
|
||||
name = "metal_basics"
|
||||
required-features = ["metal"]
|
||||
|
@ -2,11 +2,11 @@ mod benchmarks;
|
||||
|
||||
use criterion::criterion_main;
|
||||
criterion_main!(
|
||||
benchmarks::affine::benches,
|
||||
//benchmarks::affine::benches,
|
||||
benchmarks::matmul::benches,
|
||||
benchmarks::random::benches,
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
//benchmarks::random::benches,
|
||||
//benchmarks::where_cond::benches,
|
||||
//benchmarks::conv_transpose2d::benches,
|
||||
//benchmarks::qmatmul::benches,
|
||||
//benchmarks::unary::benches
|
||||
);
|
||||
|
@ -9,8 +9,10 @@ use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
@ -19,6 +21,7 @@ fn main() -> Result<()> {
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
|
28
candle-core/examples/metal_basics.rs
Normal file
28
candle-core/examples/metal_basics.rs
Normal file
@ -0,0 +1,28 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// This requires the code to be run with MTL_CAPTURE_ENABLED=1
|
||||
let device = Device::new_metal(0)?;
|
||||
let metal_device = match &device {
|
||||
Device::Metal(m) => m,
|
||||
_ => anyhow::bail!("unexpected device"),
|
||||
};
|
||||
metal_device.capture("/tmp/candle.gputrace")?;
|
||||
// This first synchronize ensures that a new command buffer gets created after setting up the
|
||||
// capture scope.
|
||||
device.synchronize()?;
|
||||
let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?;
|
||||
let x1 = x.add(&x)?;
|
||||
println!("{x1:?}");
|
||||
// This second synchronize ensures that the command buffer gets commited before the end of the
|
||||
// capture scope.
|
||||
device.synchronize()?;
|
||||
Ok(())
|
||||
}
|
@ -10,7 +10,7 @@ pub use utils::{
|
||||
};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
const USE_IM2COL_CONV1D_TR: bool = true;
|
||||
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
@ -121,7 +121,8 @@ impl ReduceIndex {
|
||||
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
||||
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
||||
let dst_to_set = dst.spare_capacity_mut();
|
||||
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
|
||||
let dst_to_set =
|
||||
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
let src = &src[o1..o2];
|
||||
@ -2249,7 +2250,7 @@ impl BackendStorage for CpuStorage {
|
||||
&& params.dilation == 1
|
||||
&& params.padding == 0
|
||||
&& params.output_padding == 0;
|
||||
if USE_IM2COL_CONV1D_TR && can_use_col2im {
|
||||
if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||
if !kernel_l.is_contiguous() {
|
||||
|
@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
let mut dst_i = 0;
|
||||
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||
f_vec(
|
||||
@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||
};
|
||||
let mut dst_i = 0;
|
||||
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||
f_vec(
|
||||
@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||
};
|
||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(len) };
|
||||
@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
|
||||
} else {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||
let ys_to_set = unsafe {
|
||||
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||
};
|
||||
let mut dst_index = 0;
|
||||
for src_index in block_start_index {
|
||||
let vs = &vs[src_index..src_index + block_len];
|
||||
|
@ -16,7 +16,7 @@ mod error;
|
||||
mod utils;
|
||||
pub use device::{CudaDevice, DeviceId};
|
||||
pub use error::{CudaError, WrapErr};
|
||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
|
||||
|
||||
pub enum SlicePtrOrNull<T> {
|
||||
Ptr(CudaSlice<T>),
|
||||
@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Col2Im1D {
|
||||
stride: usize,
|
||||
}
|
||||
|
||||
impl Map1 for Col2Im1D {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
col: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
|
||||
let stride = self.stride;
|
||||
let l_out = (l_in - 1) * stride + k_size;
|
||||
let dst_el = b_size * c_out * l_out;
|
||||
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(im)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||
|
||||
let device = self.device().clone();
|
||||
let slice =
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
let can_use_col2im = kernel_l.is_contiguous()
|
||||
&& params.dilation == 1
|
||||
&& params.padding == 0
|
||||
&& params.output_padding == 0;
|
||||
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||
if !kernel_l.is_contiguous() {
|
||||
crate::bail!(
|
||||
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
|
||||
)
|
||||
}
|
||||
if c_in != c_in2 {
|
||||
crate::bail!(
|
||||
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||
l.shape(),
|
||||
kernel_l.shape()
|
||||
)
|
||||
}
|
||||
let col = {
|
||||
// This merges the last two dimensions of the kernel together.
|
||||
let kernel_l_mm = Layout::new(
|
||||
(b_size, c_in, k_size * c_out).into(),
|
||||
vec![0, k_size * c_out, 1],
|
||||
kernel_l.start_offset(),
|
||||
);
|
||||
self.matmul(
|
||||
kernel,
|
||||
(
|
||||
b_size,
|
||||
/* m */ l_in,
|
||||
/* n */ c_out * k_size,
|
||||
/* k */ c_in,
|
||||
),
|
||||
&l.transpose(1, 2)?,
|
||||
&kernel_l_mm,
|
||||
)?
|
||||
};
|
||||
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
||||
Col2Im1D {
|
||||
stride: params.stride,
|
||||
}
|
||||
.map(&col.slice, &device, &col_l)?
|
||||
} else {
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
|
||||
};
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
@ -1964,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16(
|
||||
|
||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||
let alpha = f16::from_f32(alpha_f32);
|
||||
let beta = f16::from_f32(beta_f32);
|
||||
// The type for alpha and beta depends on the computeType.
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||
(&alpha) as *const f16 as *const _,
|
||||
(&beta) as *const f16 as *const _,
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
|
||||
(&alpha_f32) as *const f32 as *const _,
|
||||
(&beta_f32) as *const f32 as *const _,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
|
@ -54,6 +54,44 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
src3: &CudaSlice<T>,
|
||||
layout3: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn map(
|
||||
&self,
|
||||
s1: &S,
|
||||
l1: &Layout,
|
||||
s2: &S,
|
||||
l2: &Layout,
|
||||
s3: &S,
|
||||
l3: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<S> {
|
||||
let out = match (s1, s2, s3) {
|
||||
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
|
@ -273,7 +273,13 @@ impl MetalDevice {
|
||||
let descriptor = metal::CaptureDescriptor::new();
|
||||
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||
descriptor.set_capture_device(self);
|
||||
descriptor.set_output_url(path);
|
||||
// The [set_output_url] call requires an absolute path so we convert it if needed.
|
||||
if path.as_ref().is_absolute() {
|
||||
descriptor.set_output_url(path);
|
||||
} else {
|
||||
let path = std::env::current_dir()?.join(path);
|
||||
descriptor.set_output_url(path);
|
||||
}
|
||||
|
||||
capture
|
||||
.start_capture(&descriptor)
|
||||
|
@ -718,6 +718,7 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
let name = match (self.dtype, t.dtype()) {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U32, DType::F32) => "where_u32_f32",
|
||||
(DType::U8, DType::BF16) => "where_u8_bf16",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
@ -824,44 +825,107 @@ impl BackendStorage for MetalStorage {
|
||||
k_layout: &Layout,
|
||||
params: &ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||
|
||||
let can_use_col2im = k_layout.is_contiguous()
|
||||
&& params.dilation == 1
|
||||
&& params.padding == 0
|
||||
&& params.output_padding == 0;
|
||||
let l_out = params.l_out();
|
||||
let dst_el = params.c_out * l_out * params.b_size;
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose1d_f32",
|
||||
DType::F16 => "conv_transpose1d_f16",
|
||||
DType::BF16 => "conv_transpose1d_bf16",
|
||||
DType::U32 => "conv_transpose1d_u32",
|
||||
DType::U8 => "conv_transpose1d_u8",
|
||||
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
|
||||
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
|
||||
let (b_size, c_in, l_in) = layout.shape().dims3()?;
|
||||
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
|
||||
if c_in != c_in2 {
|
||||
crate::bail!(
|
||||
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||
layout.shape(),
|
||||
k_layout.shape()
|
||||
)
|
||||
}
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "col2im1d_f32",
|
||||
DType::U32 => "col2im1d_u32",
|
||||
DType::U8 => "col2im1d_u8",
|
||||
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
|
||||
};
|
||||
let col = {
|
||||
// This merges the last two dimensions of the kernel together.
|
||||
let kernel_l_mm = Layout::new(
|
||||
(b_size, c_in, k_size * c_out).into(),
|
||||
vec![0, k_size * c_out, 1],
|
||||
k_layout.start_offset(),
|
||||
);
|
||||
self.matmul(
|
||||
k,
|
||||
(b_size, l_in, c_out * k_size, c_in),
|
||||
&layout.transpose(1, 2)?,
|
||||
&kernel_l_mm,
|
||||
)?
|
||||
};
|
||||
// It is important for the command buffer to be obtained *after* the matmul
|
||||
// kernel has run, otherwise we might use a command-buffer that has been commited
|
||||
// already resulting in the following error.
|
||||
// _status < MTLCommandBufferStatusCommitted >
|
||||
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_col2im1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
&[b_size, l_in, c_out, k_size],
|
||||
params.k_size,
|
||||
params.stride,
|
||||
BufferOffset::zero_offset(&col.buffer),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
buffer
|
||||
} else {
|
||||
let buffer = self
|
||||
.device
|
||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "conv_transpose1d_f32",
|
||||
DType::F16 => "conv_transpose1d_f16",
|
||||
DType::BF16 => "conv_transpose1d_bf16",
|
||||
DType::U32 => "conv_transpose1d_u32",
|
||||
DType::U8 => "conv_transpose1d_u8",
|
||||
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_conv_transpose1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
params.dilation,
|
||||
params.stride,
|
||||
params.padding,
|
||||
params.output_padding,
|
||||
params.c_out,
|
||||
l_out,
|
||||
params.b_size,
|
||||
layout.dims(),
|
||||
layout.stride(),
|
||||
k_layout.dims(),
|
||||
k_layout.stride(),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&k.buffer,
|
||||
k_layout.start_offset() * k.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
buffer
|
||||
};
|
||||
candle_metal_kernels::call_conv_transpose1d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
params.dilation,
|
||||
params.stride,
|
||||
params.padding,
|
||||
params.output_padding,
|
||||
params.c_out,
|
||||
l_out,
|
||||
params.b_size,
|
||||
layout.dims(),
|
||||
layout.stride(),
|
||||
k_layout.dims(),
|
||||
k_layout.stride(),
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&k.buffer,
|
||||
k_layout.start_offset() * k.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
|
@ -217,10 +217,16 @@ impl Value {
|
||||
}
|
||||
}
|
||||
|
||||
/// This will also automatically upcast any integral types which will not truncate.
|
||||
pub fn to_u64(&self) -> Result<u64> {
|
||||
match self {
|
||||
Self::U64(v) => Ok(*v),
|
||||
v => crate::bail!("not a u64 {v:?}"),
|
||||
// Autoupcast cases here
|
||||
Self::U8(v) => Ok(*v as u64),
|
||||
Self::U16(v) => Ok(*v as u64),
|
||||
Self::U32(v) => Ok(*v as u64),
|
||||
Self::Bool(v) => Ok(*v as u64),
|
||||
v => crate::bail!("not a u64 or upcastable to u64 {v:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -89,7 +89,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
||||
|
||||
pub fn load() -> Result<crate::vision::Dataset> {
|
||||
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||
let dataset_id = "mnist".to_string();
|
||||
let dataset_id = "ylecun/mnist".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
|
@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
palette = { version = "0.7.6", optional = true }
|
||||
enterpolation = { version = "0.2.1", optional = true}
|
||||
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||
rayon = { workspace = true }
|
||||
rubato = { version = "0.15.0", optional = true }
|
||||
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["encodec"]
|
||||
|
||||
[[example]]
|
||||
name = "depth_anything_v2"
|
||||
required-features = ["depth_anything_v2"]
|
||||
|
20
candle-examples/examples/beit/README.md
Normal file
20
candle-examples/examples/beit/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
# candle-beit
|
||||
|
||||
[Beit](https://arxiv.org/abs/2106.08254) is a computer vision model.
|
||||
In this example, it is used as an ImageNet classifier: the model returns the
|
||||
probability for the image to belong to each of the 1000 ImageNet categories.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
cargo run --example beit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
> mountain bike, all-terrain bike, off-roader: 56.16%
|
||||
> bicycle-built-for-two, tandem bicycle, tandem: 3.08%
|
||||
> maillot : 2.23%
|
||||
> alp : 0.88%
|
||||
> crash helmet : 0.85%
|
||||
|
||||
```
|
||||
|
||||

|
79
candle-examples/examples/beit/main.rs
Normal file
79
candle-examples/examples/beit/main.rs
Normal file
@ -0,0 +1,79 @@
|
||||
//! BEiT: BERT Pre-Training of Image Transformers
|
||||
//! https://github.com/microsoft/unilm/tree/master/beit
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::beit;
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 384, 384). Beit special normalization is applied.
|
||||
pub fn load_image384_beit_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let mean = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.5f32, 0.5, 0.5], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = load_image384_beit_norm(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("vincent-espitalier/candle-beit".into());
|
||||
api.get("beit_base_patch16_384.in22k_ft_in22k_in1k.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = beit::vit_base(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
# candle-dinov2
|
||||
|
||||
[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
|
||||
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
|
||||
|
||||
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
|
||||
|
||||
## Running an example with color map and CUDA
|
||||
|
||||
```bash
|
||||
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
```
|
||||
|
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use enterpolation::linear::ConstEquidistantLinear;
|
||||
use enterpolation::Generator;
|
||||
use palette::LinSrgb;
|
||||
|
||||
use candle::Tensor;
|
||||
|
||||
pub struct SpectralRColormap {
|
||||
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
|
||||
}
|
||||
|
||||
impl SpectralRColormap {
|
||||
pub(crate) fn new() -> Self {
|
||||
// Define a colormap similar to 'Spectral_r' by specifying key colors.
|
||||
// got the colors from ChatGPT-4o
|
||||
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
|
||||
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
|
||||
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
|
||||
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
|
||||
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
|
||||
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
|
||||
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
|
||||
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
|
||||
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
|
||||
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
|
||||
]);
|
||||
Self { gradient }
|
||||
}
|
||||
|
||||
fn get_color(&self, value: f32) -> LinSrgb {
|
||||
self.gradient.gen(value)
|
||||
}
|
||||
|
||||
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
|
||||
println!("Gray: {:?}", gray.dims());
|
||||
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
|
||||
let rgb_values: Vec<f32> = gray_values
|
||||
.iter()
|
||||
.map(|g| self.get_color(*g))
|
||||
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
|
||||
.collect();
|
||||
|
||||
let [.., height, width] = gray.dims() else {
|
||||
candle::bail!("Not enough dims!")
|
||||
};
|
||||
|
||||
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
|
||||
|
||||
color.permute((2, 0, 1))
|
||||
}
|
||||
}
|
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
@ -0,0 +1,187 @@
|
||||
//! Depth Anything V2
|
||||
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::DType::{F32, U8};
|
||||
use candle::{DType, Device, Module, Result, Tensor};
|
||||
use candle_examples::{load_image, load_image_and_resize, save_image};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
|
||||
use candle_transformers::models::dinov2;
|
||||
|
||||
use crate::color_map::SpectralRColormap;
|
||||
|
||||
mod color_map;
|
||||
|
||||
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
|
||||
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
|
||||
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
|
||||
|
||||
const DINO_IMG_SIZE: usize = 518;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
dinov2_model: Option<PathBuf>,
|
||||
|
||||
#[arg(long)]
|
||||
depth_anything_v2_model: Option<PathBuf>,
|
||||
|
||||
#[arg(long)]
|
||||
image: PathBuf,
|
||||
|
||||
#[arg(long)]
|
||||
output_dir: Option<PathBuf>,
|
||||
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
color_map: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let dinov2_model_file = match args.dinov2_model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("lmz/candle-dino-v2".into());
|
||||
api.get("dinov2_vits14.safetensors")?
|
||||
}
|
||||
Some(dinov2_model) => dinov2_model,
|
||||
};
|
||||
println!("Using file {:?}", dinov2_model_file);
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
|
||||
let dinov2 = dinov2::vit_small(vb)?;
|
||||
println!("DinoV2 model built");
|
||||
|
||||
let depth_anything_model_file = match args.depth_anything_v2_model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
|
||||
api.get("depth_anything_v2_vits.safetensors")?
|
||||
}
|
||||
Some(depth_anything_model) => depth_anything_model,
|
||||
};
|
||||
println!("Using file {:?}", depth_anything_model_file);
|
||||
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
|
||||
};
|
||||
|
||||
let config = DepthAnythingV2Config::vit_small();
|
||||
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
|
||||
|
||||
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
|
||||
|
||||
println!("Loaded image {image:?}");
|
||||
|
||||
let depth = depth_anything.forward(&image)?;
|
||||
|
||||
println!("Got predictions {:?}", depth.shape());
|
||||
|
||||
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
|
||||
|
||||
let output_path = full_output_path(&args.image, &args.output_dir);
|
||||
println!("Saving image to {}", output_path.to_string_lossy());
|
||||
save_image(&output_image, output_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
|
||||
let input_file_name = image_path.file_name().unwrap();
|
||||
let mut output_file_name = OsString::from("depth_");
|
||||
output_file_name.push(input_file_name);
|
||||
let mut output_path = match output_dir {
|
||||
None => image_path.parent().unwrap().to_path_buf(),
|
||||
Some(output_path) => output_path.clone(),
|
||||
};
|
||||
output_path.push(output_file_name);
|
||||
|
||||
output_path
|
||||
}
|
||||
|
||||
fn load_and_prep_image(
|
||||
image_path: &PathBuf,
|
||||
device: &Device,
|
||||
) -> anyhow::Result<(usize, usize, Tensor)> {
|
||||
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
|
||||
|
||||
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
|
||||
.unsqueeze(0)?
|
||||
.to_dtype(F32)?
|
||||
.to_device(&device)?;
|
||||
|
||||
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||
.to_device(&device)?
|
||||
.broadcast_as(image.shape())?;
|
||||
let image = (image / max_pixel_val)?;
|
||||
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
|
||||
|
||||
Ok((original_height, original_width, image))
|
||||
}
|
||||
|
||||
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
|
||||
let mean_tensor =
|
||||
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||
let std_tensor =
|
||||
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||
image.sub(&mean_tensor)?.div(&std_tensor)
|
||||
}
|
||||
|
||||
fn post_process_image(
|
||||
image: &Tensor,
|
||||
original_height: usize,
|
||||
original_width: usize,
|
||||
color_map: bool,
|
||||
) -> Result<Tensor> {
|
||||
let out = image.interpolate2d(original_height, original_width)?;
|
||||
let out = scale_image(&out)?;
|
||||
|
||||
let out = if color_map {
|
||||
let spectral_r = SpectralRColormap::new();
|
||||
spectral_r.gray2color(&out)?
|
||||
} else {
|
||||
let rgb_slice = [&out, &out, &out];
|
||||
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
|
||||
};
|
||||
|
||||
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||
.to_device(out.device())?
|
||||
.broadcast_as(out.shape())?;
|
||||
let out = (out * max_pixel_val)?;
|
||||
|
||||
out.to_dtype(U8)
|
||||
}
|
||||
|
||||
fn scale_image(depth: &Tensor) -> Result<Tensor> {
|
||||
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
|
||||
|
||||
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
|
||||
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
|
||||
|
||||
let min_val_tensor = Tensor::try_from(*min_val)?
|
||||
.to_device(depth.device())?
|
||||
.broadcast_as(depth.shape())?;
|
||||
let depth = (depth - min_val_tensor)?;
|
||||
|
||||
let range = max_val - min_val;
|
||||
let range_tensor = Tensor::try_from(range)?
|
||||
.to_device(depth.device())?
|
||||
.broadcast_as(depth.shape())?;
|
||||
|
||||
depth / range_tensor
|
||||
}
|
25
candle-examples/examples/dinov2reg4/README.md
Normal file
25
candle-examples/examples/dinov2reg4/README.md
Normal file
@ -0,0 +1,25 @@
|
||||
# candle-dinov2-reg4
|
||||
|
||||
[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers.
|
||||
In this example, it is used as an plant species classifier: the model returns the
|
||||
probability for the image to belong to each of the 7806 PlantCLEF2024 categories.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
# Download classes names and a plant picture to identify
|
||||
curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt
|
||||
curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
|
||||
|
||||
# Perform inference
|
||||
cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg
|
||||
|
||||
> Orchis simia Lam. : 45.55%
|
||||
> Orchis × bergonii Nanteuil: 9.80%
|
||||
> Orchis italica Poir. : 9.66%
|
||||
> Orchis × angusticruris Franch.: 2.76%
|
||||
> Orchis × bivonae Tod. : 2.54%
|
||||
|
||||
```
|
||||
|
||||

|
70
candle-examples/examples/dinov2reg4/main.rs
Normal file
70
candle-examples/examples/dinov2reg4/main.rs
Normal file
@ -0,0 +1,70 @@
|
||||
//! DINOv2 reg4 finetuned on PlantCLEF 2024
|
||||
//! https://arxiv.org/abs/2309.16588
|
||||
//! https://huggingface.co/spaces/BVRA/PlantCLEF2024
|
||||
//! https://zenodo.org/records/10848263
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::dinov2reg4;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt";
|
||||
let classes: Vec<String> = std::fs::read_to_string(f_species_id_mapping)
|
||||
.expect("missing classes file")
|
||||
.split('\n')
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api =
|
||||
api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into());
|
||||
api.get(
|
||||
"vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors",
|
||||
)?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = dinov2reg4::vit_base(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!("{:24}: {:.2}%", classes[category_idx], 100. * pr);
|
||||
}
|
||||
Ok(())
|
||||
}
|
21
candle-examples/examples/eva2/README.md
Normal file
21
candle-examples/examples/eva2/README.md
Normal file
@ -0,0 +1,21 @@
|
||||
# candle-eva2
|
||||
|
||||
[EVA-02](https://arxiv.org/abs/2303.11331) is a computer vision model.
|
||||
In this example, it is used as an ImageNet classifier: the model returns the
|
||||
probability for the image to belong to each of the 1000 ImageNet categories.
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
cargo run --example eva2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||
|
||||
> mountain bike, all-terrain bike, off-roader: 37.09%
|
||||
> maillot : 8.30%
|
||||
> alp : 2.13%
|
||||
> bicycle-built-for-two, tandem bicycle, tandem: 0.84%
|
||||
> crash helmet : 0.73%
|
||||
|
||||
|
||||
```
|
||||
|
||||

|
82
candle-examples/examples/eva2/main.rs
Normal file
82
candle-examples/examples/eva2/main.rs
Normal file
@ -0,0 +1,82 @@
|
||||
//! EVA-02: Explore the limits of Visual representation at scAle
|
||||
//! https://github.com/baaivision/EVA
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::eva2;
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 448, 448). OpenAI normalization is applied.
|
||||
pub fn load_image448_openai_norm<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(448, 448, image::imageops::FilterType::Triangle);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (448, 448, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let mean =
|
||||
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
|
||||
.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
.broadcast_sub(&mean)?
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = load_image448_openai_norm(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model("vincent-espitalier/candle-eva2".into());
|
||||
api.get("eva02_base_patch14_448.mim_in22k_ft_in22k_in1k_adapted.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
|
||||
let model = eva2::vit_base(vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
4
candle-examples/examples/llava/constants.rs
Normal file
4
candle-examples/examples/llava/constants.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
|
||||
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
|
||||
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
|
||||
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";
|
114
candle-examples/examples/llava/conversation.rs
Normal file
114
candle-examples/examples/llava/conversation.rs
Normal file
@ -0,0 +1,114 @@
|
||||
pub enum SeparatorStyle {
|
||||
Two,
|
||||
Mpt,
|
||||
}
|
||||
pub struct Conversation {
|
||||
pub system: String,
|
||||
pub roles: Vec<String>,
|
||||
pub messages: Vec<(String, Option<String>)>,
|
||||
pub offset: i32,
|
||||
pub sep_style: SeparatorStyle,
|
||||
pub sep: String,
|
||||
pub sep2: Option<String>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
pub fn new(
|
||||
system: &str,
|
||||
roles: &[String],
|
||||
offset: i32,
|
||||
sep_style: SeparatorStyle,
|
||||
sep: &str,
|
||||
sep2: Option<&str>,
|
||||
version: &str,
|
||||
) -> Self {
|
||||
Conversation {
|
||||
system: system.to_string(),
|
||||
roles: roles.to_vec(),
|
||||
messages: Vec::new(),
|
||||
offset,
|
||||
sep_style,
|
||||
sep: sep.to_string(),
|
||||
sep2: sep2.map(|s| s.to_string()),
|
||||
version: version.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv_chatml_direct() -> Self {
|
||||
Conversation::new(
|
||||
"<|im_start|>system\nAnswer the questions.",
|
||||
&[
|
||||
"<|im_start|>user\n".to_string(),
|
||||
"<|im_start|>assistant\n".to_string(),
|
||||
],
|
||||
0,
|
||||
SeparatorStyle::Mpt,
|
||||
"<|im_end|>",
|
||||
None,
|
||||
"mpt",
|
||||
)
|
||||
}
|
||||
|
||||
pub fn conv_llava_v1() -> Self {
|
||||
Conversation::new(
|
||||
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
&[
|
||||
"USER".to_string(),
|
||||
"ASSISTANT".to_string(),
|
||||
],
|
||||
0,
|
||||
SeparatorStyle::Two,
|
||||
" ",
|
||||
Some("</s>"),
|
||||
"v1"
|
||||
)
|
||||
}
|
||||
|
||||
pub fn append_message(&mut self, role: String, message: Option<&str>) {
|
||||
self.messages.push((role, message.map(|s| s.to_string())))
|
||||
}
|
||||
|
||||
pub fn append_user_message(&mut self, message: Option<&str>) {
|
||||
self.append_message(self.roles[0].clone(), message);
|
||||
}
|
||||
|
||||
pub fn append_assistant_message(&mut self, message: Option<&str>) {
|
||||
self.append_message(self.roles[1].clone(), message);
|
||||
}
|
||||
|
||||
pub fn get_prompt(&self) -> String {
|
||||
match self.sep_style {
|
||||
SeparatorStyle::Mpt => {
|
||||
let mut ret = String::new();
|
||||
ret.push_str(&self.system);
|
||||
ret.push_str(&self.sep);
|
||||
for (role, message) in &self.messages {
|
||||
ret.push_str(role);
|
||||
if let Some(message) = message {
|
||||
ret.push_str(message);
|
||||
};
|
||||
ret.push_str(&self.sep);
|
||||
}
|
||||
ret
|
||||
}
|
||||
SeparatorStyle::Two => {
|
||||
let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
|
||||
let mut ret = String::new();
|
||||
ret.push_str(&self.system);
|
||||
ret.push_str(&seps[0]);
|
||||
for (i, (role, message)) in self.messages.iter().enumerate() {
|
||||
ret.push_str(role);
|
||||
if let Some(message) = message {
|
||||
ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
|
||||
ret.push_str(message);
|
||||
ret.push_str(&seps[i % 2]);
|
||||
} else {
|
||||
ret.push(':')
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
317
candle-examples/examples/llava/image_processor.rs
Normal file
317
candle-examples/examples/llava/image_processor.rs
Normal file
@ -0,0 +1,317 @@
|
||||
use std::cmp::min;
|
||||
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use candle_transformers::models::llava::{
|
||||
config::{HFPreProcessorConfig, LLaVAConfig},
|
||||
utils::select_best_resolution,
|
||||
};
|
||||
use hf_hub::api::sync::Api;
|
||||
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
//This struct is mainly for LLaVA aplications, hence it's not completely compatible with python transformer CLIPImageProcessor few several preprocess that LLaVA used, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ImageProcessor {
|
||||
#[serde(default = "default_size")]
|
||||
pub size: u32, // this is not the same as python transformer
|
||||
#[serde(default = "default_do_resize")]
|
||||
pub do_resize: bool,
|
||||
|
||||
//resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom
|
||||
#[serde(default = "default_do_center_crop")]
|
||||
pub do_center_crop: bool,
|
||||
#[serde(default = "default_crop_size")]
|
||||
pub crop_size: u32, // this is not the same as python transformer
|
||||
#[serde(default = "default_do_rescale")]
|
||||
pub do_rescale: bool,
|
||||
#[serde(default = "default_rescale_factor")]
|
||||
pub rescale_factor: f32,
|
||||
#[serde(default = "default_do_normalize")]
|
||||
pub do_normalize: bool,
|
||||
#[serde(default = "default_image_mean")]
|
||||
pub image_mean: Vec<f32>,
|
||||
#[serde(default = "default_image_std")]
|
||||
pub image_std: Vec<f32>,
|
||||
}
|
||||
|
||||
fn default_size() -> u32 {
|
||||
224
|
||||
}
|
||||
|
||||
fn default_do_resize() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_do_center_crop() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_crop_size() -> u32 {
|
||||
224
|
||||
}
|
||||
|
||||
fn default_do_rescale() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_rescale_factor() -> f32 {
|
||||
1.0 / 255.0
|
||||
}
|
||||
|
||||
fn default_do_normalize() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_image_mean() -> Vec<f32> {
|
||||
vec![0.48145466, 0.4578275, 0.40821073]
|
||||
}
|
||||
|
||||
fn default_image_std() -> Vec<f32> {
|
||||
vec![0.26862954, 0.2613026, 0.2757771]
|
||||
}
|
||||
|
||||
impl ImageProcessor {
|
||||
pub fn from_pretrained(clip_id: &str) -> Result<Self> {
|
||||
let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
|
||||
let api = api.model(clip_id.to_string());
|
||||
let config_filename = api
|
||||
.get("preprocessor_config.json")
|
||||
.map_err(|e| candle::Error::Msg(e.to_string()))?;
|
||||
let image_processor =
|
||||
serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
|
||||
.map_err(|e| candle::Error::Msg(e.to_string()))?;
|
||||
Ok(image_processor)
|
||||
}
|
||||
|
||||
pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
|
||||
Self {
|
||||
size: hf_preprocessor_config.size["shortest_edge"] as u32,
|
||||
do_resize: hf_preprocessor_config.do_resize,
|
||||
do_center_crop: hf_preprocessor_config.do_center_crop,
|
||||
crop_size: hf_preprocessor_config.crop_size["height"] as u32,
|
||||
do_rescale: hf_preprocessor_config.do_rescale,
|
||||
rescale_factor: hf_preprocessor_config.rescale_factor,
|
||||
do_normalize: hf_preprocessor_config.do_normalize,
|
||||
image_mean: hf_preprocessor_config.image_mean.clone(),
|
||||
image_std: hf_preprocessor_config.image_std.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
///shortest edge to self.resize, other edge is resized to maintain aspect ratio
|
||||
pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let size = self.size;
|
||||
if width == size && height == size {
|
||||
image.clone()
|
||||
} else {
|
||||
let (new_width, new_height) = if width < height {
|
||||
(
|
||||
size,
|
||||
(((size * height) as f32) / width as f32).ceil() as u32,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
(((size * width) as f32) / height as f32).ceil() as u32,
|
||||
size,
|
||||
)
|
||||
};
|
||||
image.resize(
|
||||
new_width,
|
||||
new_height,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let crop_size = self.crop_size;
|
||||
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
|
||||
image.crop_imm(left, top, crop_size, crop_size)
|
||||
}
|
||||
|
||||
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
|
||||
let img = image.to_rgb8().into_raw();
|
||||
let (width, height) = image.dimensions();
|
||||
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
|
||||
.to_dtype(DType::F32) // only for internal compute
|
||||
}
|
||||
|
||||
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
let rescale_factor = self.rescale_factor as f64;
|
||||
tensor.affine(rescale_factor, 0.0)
|
||||
}
|
||||
|
||||
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
let image_mean = self.image_mean.clone();
|
||||
let image_std = self.image_std.clone();
|
||||
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
|
||||
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
|
||||
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
|
||||
}
|
||||
|
||||
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
|
||||
tensor.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
|
||||
let image = if self.do_resize {
|
||||
self.resize(image)
|
||||
} else {
|
||||
image.clone()
|
||||
};
|
||||
let image = if self.do_center_crop {
|
||||
self.center_crop(&image)
|
||||
} else {
|
||||
image
|
||||
};
|
||||
let tensor = self.to_tensor(&image)?;
|
||||
let tensor = if self.do_rescale {
|
||||
self.rescale(&tensor)?
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
let tensor = if self.do_normalize {
|
||||
self.normalize(&tensor)?
|
||||
} else {
|
||||
tensor
|
||||
};
|
||||
self.to_channel_dimension_format(&tensor)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
|
||||
let (width, height) = image_size;
|
||||
let (center_width, center_height) = center_size;
|
||||
let left = if width <= center_width {
|
||||
0
|
||||
} else {
|
||||
((width as f32 - center_width as f32) / 2.0).ceil() as u32
|
||||
};
|
||||
let top = if height <= center_height {
|
||||
0
|
||||
} else {
|
||||
((height as f32 - center_height as f32) / 2.0).ceil() as u32
|
||||
};
|
||||
(left, top)
|
||||
}
|
||||
|
||||
pub fn process_image(
|
||||
image: &DynamicImage,
|
||||
processor: &ImageProcessor,
|
||||
llava_config: &LLaVAConfig,
|
||||
) -> candle::Result<Tensor> {
|
||||
if llava_config.image_aspect_ratio == *"square" {
|
||||
processor.preprocess(image)?.unsqueeze(0)
|
||||
} else if llava_config.image_aspect_ratio == *"anyres" {
|
||||
process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
|
||||
} else if llava_config.image_aspect_ratio == *"pad" {
|
||||
process_pad_image(image, processor)
|
||||
} else {
|
||||
bail!("Invalid image aspect ratio")
|
||||
}
|
||||
}
|
||||
|
||||
fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
|
||||
let mean_color = processor
|
||||
.image_mean
|
||||
.iter()
|
||||
.map(|x| ((*x) * 255.0) as u8)
|
||||
.collect::<Vec<u8>>();
|
||||
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
|
||||
let image_padded = expand2square(image, mean_color);
|
||||
processor.preprocess(&image_padded)
|
||||
}
|
||||
|
||||
fn process_anyres_image(
|
||||
image: &DynamicImage,
|
||||
processor: &ImageProcessor,
|
||||
grid_pinpoints: &[(u32, u32)],
|
||||
) -> Result<Tensor> {
|
||||
let original_size = image.dimensions();
|
||||
let best_resolution = select_best_resolution(original_size, grid_pinpoints);
|
||||
let image_padded = resize_and_pad_image(image, best_resolution);
|
||||
let image_original_resize = image.resize_exact(
|
||||
processor.size,
|
||||
processor.size,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
let mut patches = vec![image_original_resize];
|
||||
for patch in divide_to_patches(&image_padded, processor.crop_size) {
|
||||
patches.push(patch);
|
||||
}
|
||||
let tensors = patches
|
||||
.iter()
|
||||
.map(|patch| processor.preprocess(patch))
|
||||
.collect::<Result<Vec<Tensor>>>()?;
|
||||
Tensor::stack(&tensors, 0)
|
||||
}
|
||||
|
||||
fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
|
||||
let (width, height) = image.dimensions();
|
||||
match width.cmp(&height) {
|
||||
std::cmp::Ordering::Less => {
|
||||
let mut new_image =
|
||||
DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
|
||||
overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
|
||||
new_image
|
||||
}
|
||||
std::cmp::Ordering::Equal => image.clone(),
|
||||
std::cmp::Ordering::Greater => {
|
||||
let mut new_image =
|
||||
DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
|
||||
overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
|
||||
new_image
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
|
||||
let (original_width, original_height) = image.dimensions();
|
||||
let original_width_f = original_width as f32;
|
||||
let original_height_f = original_height as f32;
|
||||
let (target_width, target_height) = target_resolution;
|
||||
let target_width_f = target_width as f32;
|
||||
let target_height_f = target_height as f32;
|
||||
let scale_w = target_width_f / original_width_f;
|
||||
let scale_h = target_height_f / original_height_f;
|
||||
let (new_width, new_height) = if scale_w < scale_h {
|
||||
(
|
||||
target_width,
|
||||
min((original_height_f * scale_w).ceil() as u32, target_height),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
min((original_width_f * scale_h).ceil() as u32, target_width),
|
||||
target_height,
|
||||
)
|
||||
};
|
||||
let resized_image = image.resize_exact(
|
||||
new_width,
|
||||
new_height,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
|
||||
let (paste_x, paste_y) =
|
||||
calculate_middle((target_width, target_height), (new_width, new_height));
|
||||
overlay(
|
||||
&mut new_image,
|
||||
&resized_image,
|
||||
paste_x.into(),
|
||||
paste_y.into(),
|
||||
);
|
||||
new_image
|
||||
}
|
||||
|
||||
fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut patches = Vec::new();
|
||||
for y in (0..height).step_by(patch_size as usize) {
|
||||
for x in (0..width).step_by(patch_size as usize) {
|
||||
let patch = image.crop_imm(x, y, patch_size, patch_size);
|
||||
patches.push(patch);
|
||||
}
|
||||
}
|
||||
patches
|
||||
}
|
316
candle-examples/examples/llava/main.rs
Normal file
316
candle-examples/examples/llava/main.rs
Normal file
@ -0,0 +1,316 @@
|
||||
pub mod constants;
|
||||
pub mod conversation;
|
||||
pub mod image_processor;
|
||||
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use candle_transformers::models::llama::Cache;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::llava::config::{
|
||||
HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
|
||||
};
|
||||
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
|
||||
use clap::Parser;
|
||||
use constants::*;
|
||||
use conversation::Conversation;
|
||||
use hf_hub::api::sync::Api;
|
||||
use image_processor::{process_image, ImageProcessor};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about,long_about=None)]
|
||||
struct Args {
|
||||
#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
|
||||
model_path: String,
|
||||
#[arg(long, default_value = "tokenizer/tokenizer.json")]
|
||||
tokenizer_path: String,
|
||||
#[arg(long)]
|
||||
model_base: Option<String>,
|
||||
#[arg(long)]
|
||||
image_file: String, // Required
|
||||
#[arg(long)]
|
||||
conv_mode: Option<String>,
|
||||
#[arg(long, default_value_t = 0.2)]
|
||||
temperature: f32,
|
||||
#[arg(long, default_value_t = 512)]
|
||||
max_new_tokens: usize,
|
||||
#[arg(long, action)]
|
||||
hf: bool,
|
||||
#[arg(long, action)]
|
||||
cpu: bool,
|
||||
#[arg(long, action)]
|
||||
no_kv_cache: bool,
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
/// The seed to use when generating random samples. Copy from candle llama. Not exist in python llava.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
|
||||
fn load_image<T: AsRef<std::path::Path>>(
|
||||
path: T,
|
||||
processor: &ImageProcessor,
|
||||
llava_config: &LLaVAConfig,
|
||||
dtype: DType,
|
||||
) -> Result<((u32, u32), Tensor)> {
|
||||
let img = image::io::Reader::open(path)?.decode()?;
|
||||
let img_tensor = process_image(&img, processor, llava_config)?;
|
||||
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
|
||||
}
|
||||
|
||||
fn get_model_name_from_path(model_path: &str) -> String {
|
||||
let model_paths: Vec<String> = model_path
|
||||
.trim_matches('/')
|
||||
.split('/')
|
||||
.map(|s| s.to_string())
|
||||
.collect();
|
||||
if model_paths.last().unwrap().starts_with("checkpoint-") {
|
||||
format!(
|
||||
"{}_{}",
|
||||
model_paths[model_paths.len() - 2],
|
||||
model_paths.last().unwrap()
|
||||
)
|
||||
} else {
|
||||
model_paths.last().unwrap().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
let mut res = Vec::new();
|
||||
for _ in 0..n {
|
||||
res.extend(vec.to_owned());
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
let sep = vec![sep];
|
||||
let sep = duplicate_vec(&sep, x.len());
|
||||
let mut res = x
|
||||
.iter()
|
||||
.zip(sep.iter())
|
||||
.flat_map(|(x, y)| vec![x.clone(), y.clone()])
|
||||
.collect::<Vec<Vec<T>>>();
|
||||
res.pop();
|
||||
res
|
||||
}
|
||||
|
||||
fn tokenizer_image_token(
|
||||
prompt: &str,
|
||||
tokenizer: &Tokenizer,
|
||||
image_token_index: i64,
|
||||
llava_config: &LLaVAConfig,
|
||||
) -> Result<Tensor> {
|
||||
let prompt_chunks = prompt
|
||||
.split("<image>")
|
||||
.map(|s| {
|
||||
tokenizer
|
||||
.encode(s, true)
|
||||
.unwrap()
|
||||
.get_ids()
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|x| *x as i64)
|
||||
.collect()
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
let mut input_ids = Vec::new();
|
||||
let mut offset = 0;
|
||||
if !prompt_chunks.is_empty()
|
||||
&& !prompt_chunks[0].is_empty()
|
||||
&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
|
||||
{
|
||||
offset = 1;
|
||||
input_ids.push(prompt_chunks[0][0]);
|
||||
}
|
||||
|
||||
for x in insert_separator(
|
||||
prompt_chunks,
|
||||
duplicate_vec(&[image_token_index], offset + 1),
|
||||
)
|
||||
.iter()
|
||||
{
|
||||
input_ids.extend(x[1..].to_vec())
|
||||
}
|
||||
let input_len = input_ids.len();
|
||||
Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let mut args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
println!("Start loading model");
|
||||
let api = Api::new()?;
|
||||
let api = api.model(args.model_path.clone());
|
||||
let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
|
||||
let config_filename = api.get("config.json")?;
|
||||
let hf_llava_config: HFLLaVAConfig =
|
||||
serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let generation_config_filename = api.get("generation_config.json")?;
|
||||
let generation_config: HFGenerationConfig =
|
||||
serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
|
||||
let preprocessor_config_filename = api.get("preprocessor_config.json")?;
|
||||
let preprocessor_config: HFPreProcessorConfig =
|
||||
serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
|
||||
let llava_config =
|
||||
hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let clip_vision_config = hf_llava_config.to_clip_vision_config();
|
||||
(
|
||||
llava_config,
|
||||
tokenizer,
|
||||
Some(clip_vision_config),
|
||||
ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
|
||||
)
|
||||
} else {
|
||||
let config_filename = api.get("config.json")?;
|
||||
let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
|
||||
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
|
||||
(
|
||||
llava_config.clone(),
|
||||
tokenizer,
|
||||
None,
|
||||
ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
|
||||
)
|
||||
};
|
||||
|
||||
let llama_config = llava_config.to_llama_config();
|
||||
let dtype: DType = match llava_config.torch_dtype.as_str() {
|
||||
"float16" => DType::F16,
|
||||
"bfloat16" => DType::BF16,
|
||||
_ => bail!("unsupported dtype"),
|
||||
};
|
||||
|
||||
let eos_token_id = llava_config.eos_token_id;
|
||||
|
||||
println!("setting kv cache");
|
||||
let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
|
||||
|
||||
println!("loading model weights");
|
||||
|
||||
let weight_filenames =
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
|
||||
let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
|
||||
|
||||
println!("generating conv template");
|
||||
let image_token_se = format!(
|
||||
"{}{}{}",
|
||||
DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
|
||||
);
|
||||
let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
|
||||
if llava_config.mm_use_im_start_end {
|
||||
args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
|
||||
} else {
|
||||
args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
|
||||
}
|
||||
} else if llava_config.mm_use_im_start_end {
|
||||
format!("{}\n{}", image_token_se, args.prompt)
|
||||
} else {
|
||||
format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
|
||||
};
|
||||
|
||||
let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
|
||||
let conv_mode = if model_name.contains("llama-2") {
|
||||
"llava_llama_2"
|
||||
} else if model_name.contains("mistral") {
|
||||
"mistral_instruct"
|
||||
} else if model_name.contains("v1.6-34b") {
|
||||
"chatml_direct"
|
||||
} else if model_name.contains("v1") {
|
||||
"llava_v1"
|
||||
} else if model_name.contains("mpt") {
|
||||
"mpt"
|
||||
} else {
|
||||
"llava_v0"
|
||||
};
|
||||
if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
|
||||
println!(
|
||||
"Warning: the model is trained with {}, but you are using {}",
|
||||
conv_mode,
|
||||
args.conv_mode.as_deref().unwrap()
|
||||
);
|
||||
} else {
|
||||
args.conv_mode = Some(conv_mode.to_string());
|
||||
}
|
||||
|
||||
let mut conv = match args.conv_mode {
|
||||
Some(conv_mode) => match conv_mode.as_str() {
|
||||
"chatml_direct" => Conversation::conv_chatml_direct(),
|
||||
"llava_v1" => Conversation::conv_llava_v1(),
|
||||
_ => todo!("not implement yet"),
|
||||
},
|
||||
None => bail!("conv_mode is required"),
|
||||
};
|
||||
conv.append_user_message(Some(&qs));
|
||||
conv.append_assistant_message(None);
|
||||
let prompt = conv.get_prompt();
|
||||
println!("loading image");
|
||||
let (image_size, image_tensor) =
|
||||
load_image(&args.image_file, &image_processor, &llava_config, dtype)
|
||||
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
|
||||
let image_tensor = image_tensor.to_device(&device)?;
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = f64::from(args.temperature);
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
Sampling::All { temperature }
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
// get input tokens
|
||||
let tokens = tokenizer_image_token(
|
||||
&prompt,
|
||||
&tokenizer,
|
||||
llava_config.image_token_index as i64,
|
||||
&llava_config,
|
||||
)?;
|
||||
let mut input_embeds =
|
||||
llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
|
||||
//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
|
||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||
let mut index_pos = 0;
|
||||
for index in 0..args.max_new_tokens {
|
||||
let (_, input_embeds_len, _) = input_embeds.dims3()?;
|
||||
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
|
||||
(1, index_pos)
|
||||
} else {
|
||||
(input_embeds_len, 0)
|
||||
};
|
||||
let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
|
||||
let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
|
||||
let logits = logits.squeeze(0)?;
|
||||
let (_, input_len, _) = input.dims3()?;
|
||||
index_pos += input_len;
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
|
||||
let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
|
||||
input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
|
||||
if next_token == eos_token_id as u32 {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
40
candle-examples/examples/llava/readme.md
Normal file
40
candle-examples/examples/llava/readme.md
Normal file
@ -0,0 +1,40 @@
|
||||
# candle-llava
|
||||
|
||||
LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
|
||||
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava)
|
||||
|
||||
The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), Hence the llava-hf version of config may perform differently.
|
||||
|
||||
## model zoo
|
||||
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
|
||||
* [llava-hf](https://huggingface.co/llava-hf)
|
||||
|
||||
Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
|
||||
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.
|
||||
|
||||
## Tokenizer Setup
|
||||
The llava-hf models contain a `tokenizer.json` file so can be used directly with
|
||||
the `-hf` command line flag.
|
||||
|
||||
For the original llava models, you can use the following code to generate the `tokenizer.json` file.
|
||||
|
||||
```bash
|
||||
conda create -n llava python=3.10
|
||||
pip install transformers protobuf
|
||||
conda activate llava
|
||||
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
|
||||
```
|
||||
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).
|
||||
|
||||
|
||||
## eval
|
||||
|
||||
```bash
|
||||
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
|
||||
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
|
||||
```
|
||||
|
||||
## Major Limitations
|
||||
1. Currently only support llama-2/vicuna llm. Haven't supoort Mistral yet.
|
||||
2. There are some ops like split, nonzero and where are not supported by candle.
|
||||
3. Lack of quantization and LoRA support.
|
18
candle-examples/examples/mobilenetv4/README.md
Normal file
18
candle-examples/examples/mobilenetv4/README.md
Normal file
@ -0,0 +1,18 @@
|
||||
# candle-mobilenetv4
|
||||
|
||||
[MobileNetV4 - Universal Models for the Mobile Ecosystem](https://arxiv.org/abs/2404.10518)
|
||||
This candle implementation uses pre-trained MobileNetV4 models from timm for inference.
|
||||
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
|
||||
|
||||
## Running an example
|
||||
|
||||
```
|
||||
$ cargo run --example mobilenetv4 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which medium
|
||||
loaded image Tensor[dims 3, 256, 256; f32]
|
||||
model built
|
||||
unicycle, monocycle : 20.18%
|
||||
mountain bike, all-terrain bike, off-roader: 19.77%
|
||||
bicycle-built-for-two, tandem bicycle, tandem: 15.91%
|
||||
crash helmet : 1.15%
|
||||
tricycle, trike, velocipede: 0.67%
|
||||
```
|
106
candle-examples/examples/mobilenetv4/main.rs
Normal file
106
candle-examples/examples/mobilenetv4/main.rs
Normal file
@ -0,0 +1,106 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::mobilenetv4;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
Small,
|
||||
Medium,
|
||||
Large,
|
||||
HybridMedium,
|
||||
HybridLarge,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::Small => "conv_small.e2400_r224",
|
||||
Self::Medium => "conv_medium.e500_r256",
|
||||
Self::HybridMedium => "hybrid_medium.ix_e550_r256",
|
||||
Self::Large => "conv_large.e600_r384",
|
||||
Self::HybridLarge => "hybrid_large.ix_e600_r384",
|
||||
};
|
||||
format!("timm/mobilenetv4_{}_in1k", name)
|
||||
}
|
||||
|
||||
fn resolution(&self) -> u32 {
|
||||
match self {
|
||||
Self::Small => 224,
|
||||
Self::Medium => 256,
|
||||
Self::HybridMedium => 256,
|
||||
Self::Large => 384,
|
||||
Self::HybridLarge => 384,
|
||||
}
|
||||
}
|
||||
fn config(&self) -> mobilenetv4::Config {
|
||||
match self {
|
||||
Self::Small => mobilenetv4::Config::small(),
|
||||
Self::Medium => mobilenetv4::Config::medium(),
|
||||
Self::HybridMedium => mobilenetv4::Config::hybrid_medium(),
|
||||
Self::Large => mobilenetv4::Config::large(),
|
||||
Self::HybridLarge => mobilenetv4::Config::hybrid_large(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::Small)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image(args.image, args.which.resolution())?
|
||||
.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = mobilenetv4::mobilenetv4(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -114,6 +114,10 @@ impl TextGeneration {
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
if let Some(t) = self.tokenizer.decode_rest()? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
@ -141,6 +145,8 @@ enum WhichModel {
|
||||
V2,
|
||||
#[value(name = "3")]
|
||||
V3,
|
||||
#[value(name = "3-medium")]
|
||||
V3Medium,
|
||||
#[value(name = "2-old")]
|
||||
V2Old,
|
||||
PuffinPhiV2,
|
||||
@ -254,6 +260,7 @@ fn main() -> Result<()> {
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
|
||||
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
}
|
||||
@ -273,6 +280,7 @@ fn main() -> Result<()> {
|
||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||
WhichModel::V2
|
||||
| WhichModel::V3
|
||||
| WhichModel::V3Medium
|
||||
| WhichModel::PuffinPhiV2
|
||||
| WhichModel::PhiHermes => "main".to_string(),
|
||||
}
|
||||
@ -287,7 +295,8 @@ fn main() -> Result<()> {
|
||||
| WhichModel::V1_5
|
||||
| WhichModel::V2
|
||||
| WhichModel::V2Old
|
||||
| WhichModel::V3 => repo.get("tokenizer.json")?,
|
||||
| WhichModel::V3
|
||||
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||
}
|
||||
@ -303,14 +312,14 @@ fn main() -> Result<()> {
|
||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||
WhichModel::V3 => anyhow::bail!(
|
||||
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
|
||||
"use the quantized or quantized-phi examples for quantized phi-v3"
|
||||
),
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
|
||||
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
|
||||
candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
@ -332,7 +341,7 @@ fn main() -> Result<()> {
|
||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
WhichModel::V3 => {
|
||||
WhichModel::V3 | WhichModel::V3Medium => {
|
||||
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
|
||||
}
|
||||
};
|
||||
@ -352,7 +361,9 @@ fn main() -> Result<()> {
|
||||
let dtype = match args.dtype {
|
||||
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||
None => {
|
||||
if args.model == WhichModel::V3 && device.is_cuda() {
|
||||
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
|
||||
&& device.is_cuda()
|
||||
{
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
@ -368,7 +379,7 @@ fn main() -> Result<()> {
|
||||
let phi = Phi::new(&config, vb)?;
|
||||
Model::Phi(phi)
|
||||
}
|
||||
WhichModel::V3 => {
|
||||
WhichModel::V3 | WhichModel::V3Medium => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Phi3Config = serde_json::from_str(&config)?;
|
||||
|
@ -217,7 +217,6 @@ fn main() -> anyhow::Result<()> {
|
||||
match args.which {
|
||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
|
||||
1,
|
||||
args.use_flash_attn,
|
||||
model,
|
||||
&mut file,
|
||||
|
@ -144,6 +144,14 @@ enum WhichModel {
|
||||
W72b,
|
||||
#[value(name = "moe-a2.7b")]
|
||||
MoeA27b,
|
||||
#[value(name = "2-0.5b")]
|
||||
W2_0_5b,
|
||||
#[value(name = "2-1.5b")]
|
||||
W2_1_5b,
|
||||
#[value(name = "2-7b")]
|
||||
W2_7b,
|
||||
#[value(name = "2-72b")]
|
||||
W2_72b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@ -234,16 +242,20 @@ fn main() -> Result<()> {
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => {
|
||||
let size = match args.model {
|
||||
WhichModel::W0_5b => "0.5B",
|
||||
WhichModel::W1_8b => "1.8B",
|
||||
WhichModel::W4b => "4B",
|
||||
WhichModel::W7b => "7B",
|
||||
WhichModel::W14b => "14B",
|
||||
WhichModel::W72b => "72B",
|
||||
WhichModel::MoeA27b => "MoE-A2.7B",
|
||||
let (version, size) = match args.model {
|
||||
WhichModel::W2_0_5b => ("2", "0.5B"),
|
||||
WhichModel::W2_1_5b => ("2", "1.5B"),
|
||||
WhichModel::W2_7b => ("2", "7B"),
|
||||
WhichModel::W2_72b => ("2", "72B"),
|
||||
WhichModel::W0_5b => ("1.5", "0.5B"),
|
||||
WhichModel::W1_8b => ("1.5", "1.8B"),
|
||||
WhichModel::W4b => ("1.5", "4B"),
|
||||
WhichModel::W7b => ("1.5", "7B"),
|
||||
WhichModel::W14b => ("1.5", "14B"),
|
||||
WhichModel::W72b => ("1.5", "72B"),
|
||||
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
||||
};
|
||||
format!("Qwen/Qwen1.5-{size}")
|
||||
format!("Qwen/Qwen{version}-{size}")
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
@ -261,11 +273,15 @@ fn main() -> Result<()> {
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
WhichModel::W4b
|
||||
| WhichModel::W7b
|
||||
| WhichModel::W2_7b
|
||||
| WhichModel::W14b
|
||||
| WhichModel::W72b
|
||||
| WhichModel::W2_72b
|
||||
| WhichModel::MoeA27b => {
|
||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||
}
|
||||
|
@ -1,15 +1,16 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 224, 224). imagenet normalization is applied.
|
||||
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
/// Loads an image from disk using the image crate at the requested resolution.
|
||||
// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P, res: u32) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(224, 224, image::imageops::FilterType::Triangle);
|
||||
.resize_to_fill(res, res, image::imageops::FilterType::Triangle);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||
let data = Tensor::from_vec(data, (res as usize, res as usize, 3), &Device::Cpu)?
|
||||
.permute((2, 0, 1))?;
|
||||
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||
@ -17,6 +18,19 @@ pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
.broadcast_div(&std)
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 224, 224). imagenet normalization is applied.
|
||||
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
load_image(p, 224)
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 518, 518). imagenet normalization is applied.
|
||||
/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens).
|
||||
pub fn load_image518<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
load_image(p, 518)
|
||||
}
|
||||
|
||||
pub const CLASS_COUNT: i64 = 1000;
|
||||
|
||||
pub const CLASSES: [&str; 1000] = [
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -97,6 +97,50 @@ __device__ void im2col1d(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void col2im1d(
|
||||
const size_t dst_el,
|
||||
const size_t l_out,
|
||||
const size_t l_in,
|
||||
const size_t c_out,
|
||||
const size_t k_size,
|
||||
const size_t stride,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, l_in, c_out, l_k)
|
||||
// dst: (b_size, c_out, l_out)
|
||||
if (dst_i >= dst_el) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t dst_s0 = c_out * l_out;
|
||||
const size_t dst_s1 = l_out;
|
||||
const size_t src_s0 = c_out * k_size * l_in;
|
||||
const size_t src_s1 = c_out * k_size;
|
||||
const size_t src_s2 = k_size;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t c_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= c_idx * dst_s1;
|
||||
const int l_out_idx = tmp_dst_i;
|
||||
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
|
||||
int l_in_idx = l_out_idx / stride;
|
||||
int k0 = l_out_idx - l_in_idx * stride;
|
||||
// l_out_idx = l_in_idx * stride + k0
|
||||
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
|
||||
if (l_in_idx < l_in) {
|
||||
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
|
||||
dst[dst_i] += src[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void im2col(
|
||||
const size_t dst_numel,
|
||||
@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
|
||||
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
|
||||
} \
|
||||
|
||||
#define COL2IM1D_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t dst_el, \
|
||||
const size_t l_out, \
|
||||
const size_t l_in, \
|
||||
const size_t c_out, \
|
||||
const size_t k_size, \
|
||||
const size_t stride, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
|
||||
} \
|
||||
|
||||
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t dst_numel, \
|
||||
@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
||||
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
||||
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||
IM2COL_OP(__half, im2col_f16)
|
||||
IM2COL1D_OP(__half, im2col1d_f16)
|
||||
COL2IM1D_OP(__half, col2im1d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(double, im2col1d_f64)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
|
||||
COL2IM1D_OP(float, col2im1d_f32)
|
||||
COL2IM1D_OP(double, col2im1d_f64)
|
||||
COL2IM1D_OP(uint8_t, col2im1d_u8)
|
||||
COL2IM1D_OP(uint32_t, col2im1d_u32)
|
||||
|
@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
dst[dst_id] = shr[0];
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
|
||||
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
|
||||
template <typename T>
|
||||
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
const int block_size = blockDim.x;
|
||||
|
||||
float2 mean_var = make_float2(0.f, 0.f);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
mean_var.x += xi;
|
||||
mean_var.y += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float2 s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = mean_var;
|
||||
}
|
||||
__syncthreads();
|
||||
mean_var = s_sum[lane_id];
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
}
|
||||
|
||||
const float mean = mean_var.x / ncols;
|
||||
const float var = mean_var.y / ncols - mean * mean;
|
||||
const float inv_std = rsqrtf(var + eps);
|
||||
|
||||
if (alpha == nullptr && beta == nullptr) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
|
||||
dst[row*ncols + col] = static_cast<T>(lhs);
|
||||
}
|
||||
}
|
||||
else if (alpha == nullptr && beta != nullptr) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
float b = static_cast<float>(beta[col]);
|
||||
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
|
||||
dst[row*ncols + col] = static_cast<T>(lhs + b);
|
||||
}
|
||||
}
|
||||
else if (alpha != nullptr && beta == nullptr) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
float a = static_cast<float>(alpha[col]);
|
||||
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
|
||||
dst[row*ncols + col] = static_cast<T>(lhs * a);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
float a = static_cast<float>(alpha[col]);
|
||||
float b = static_cast<float>(beta[col]);
|
||||
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
|
||||
dst[row*ncols + col] = static_cast<T>(lhs * a + b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
|
||||
template <typename T>
|
||||
@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
|
||||
} \
|
||||
|
||||
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
|
||||
const TYPENAME *beta, const int n_cols, const float eps) { \
|
||||
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
|
||||
} \
|
||||
|
||||
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
|
||||
extern "C" __global__ void FN_NAME_I( \
|
||||
const TYPENAME *src, \
|
||||
@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
|
||||
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
|
||||
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
|
||||
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
|
||||
SUM_OP(__nv_bfloat16, sum_bf16)
|
||||
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
|
||||
@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
SOFTMAX_OP(__half, float, softmax_f16)
|
||||
RMSNORM_OP(__half, rmsnorm_f16)
|
||||
LAYERNORM_OP(__half, layernorm_f16)
|
||||
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
|
||||
SUM_OP(__half, sum_f16)
|
||||
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
|
||||
@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
|
||||
SOFTMAX_OP(double, double, softmax_f64)
|
||||
RMSNORM_OP(float, rmsnorm_f32)
|
||||
RMSNORM_OP(double, rmsnorm_f64)
|
||||
LAYERNORM_OP(float, layernorm_f32)
|
||||
LAYERNORM_OP(double, layernorm_f64)
|
||||
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
|
||||
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
|
||||
|
||||
|
1
candle-metal-kernels/.gitignore
vendored
Normal file
1
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
src/air
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
141
candle-metal-kernels/build.rs
Normal file
141
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,141 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use std::process::Command;
|
||||
use std::{env, str};
|
||||
|
||||
const COMPILED_KERNELS: [&str; 3] = ["event", "matrix_storage", "gemm"];
|
||||
|
||||
enum Platform {
|
||||
MacOS,
|
||||
IOS,
|
||||
}
|
||||
|
||||
impl Platform {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Platform::MacOS => "macosx",
|
||||
Platform::IOS => "iphoneos",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_xcode_sdk_path(platform: Platform) -> Result<String, String> {
|
||||
let xcrun_output = Command::new("xcrun")
|
||||
.args(["--sdk", platform.as_str(), "--show-sdk-path"])
|
||||
.output()
|
||||
.expect("xcrun command failed to start");
|
||||
|
||||
Ok(str::from_utf8(&xcrun_output.stdout)
|
||||
.expect("Invalid UTF-8 from xcrun")
|
||||
.replace('\n', ""))
|
||||
}
|
||||
|
||||
fn compile_candle_metallib(sdk_path: String, bfloat_support: bool) -> Result<(), String> {
|
||||
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||
let out_dir = current_dir.join("src/libraries");
|
||||
let air_dir = current_dir.join("src/air");
|
||||
let working_directory = air_dir.display();
|
||||
let sources = current_dir.join("src/kernels");
|
||||
|
||||
// Compile metal to air
|
||||
let mut compile_air_cmd = Command::new("xcrun");
|
||||
compile_air_cmd
|
||||
.arg("metal")
|
||||
.arg(format!("-working-directory={working_directory}"))
|
||||
.arg("-Wall")
|
||||
.arg("-Wextra")
|
||||
.arg("-O3")
|
||||
.arg("-c")
|
||||
.arg("-w");
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
compile_air_cmd.arg(sources.join(format!("{metal_file}.metal")));
|
||||
}
|
||||
compile_air_cmd.arg(sources.join("utils.metal"));
|
||||
compile_air_cmd.spawn().expect("Failed to compile air");
|
||||
|
||||
let mut child = compile_air_cmd.spawn().expect("Failed to compile air");
|
||||
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling metal -> air failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
let status = child
|
||||
.wait()
|
||||
.expect("Compiling metal -> air failed while waiting for result");
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling metal -> air failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("Compiling metal -> air failed: {:?}", e),
|
||||
}
|
||||
|
||||
// Compile air to metallib
|
||||
let metallib = out_dir.join("candle.metallib");
|
||||
|
||||
let mut compile_metallib_cmd = Command::new("xcrun");
|
||||
compile_metallib_cmd.arg("metal").arg("-o").arg(&metallib);
|
||||
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
compile_metallib_cmd.arg(air_dir.join(format!("{metal_file}.air")));
|
||||
}
|
||||
compile_metallib_cmd.arg(air_dir.join("utils.air"));
|
||||
|
||||
let mut child = compile_metallib_cmd
|
||||
.spawn()
|
||||
.expect("Failed to compile air -> metallib");
|
||||
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling air -> metallib failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
let status = child
|
||||
.wait()
|
||||
.expect("Compiling air -> metallib failed while waiting for result");
|
||||
if !status.success() {
|
||||
panic!(
|
||||
"Compiling air -> metallib failed. Exit with status: {}",
|
||||
status
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("Compiling air -> metallib failed: {:?}", e),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<(), String> {
|
||||
println!("cargo::rerun-if-changed=build.rs");
|
||||
|
||||
let current_dir = env::current_dir().expect("Failed to get current directory");
|
||||
let sources = current_dir.join("src/kernels");
|
||||
|
||||
for metal_file in COMPILED_KERNELS {
|
||||
println!(
|
||||
"cargo::rerun-if-changed={}",
|
||||
sources.join(format!("{metal_file}.metal")).display()
|
||||
);
|
||||
}
|
||||
|
||||
let macos_sdk = get_xcode_sdk_path(Platform::MacOS).expect("Failed to get MacOS SDK path");
|
||||
let iphoneos_sdk = get_xcode_sdk_path(Platform::IOS).expect("Failed to get IOS SDK path");
|
||||
|
||||
compile_candle_metallib(macos_sdk, false)?;
|
||||
|
||||
Ok(())
|
||||
}
|
65
candle-metal-kernels/src/ffi.rs
Normal file
65
candle-metal-kernels/src/ffi.rs
Normal file
@ -0,0 +1,65 @@
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
|
||||
use core::ffi::{c_char, c_int, c_uint, c_void};
|
||||
|
||||
pub type CFTypeRef = *const c_void;
|
||||
pub type CFAllocatorRef = *const c_void;
|
||||
pub type CFMutableDictionaryRef = *mut c_void;
|
||||
pub type CFStringRef = *const c_void;
|
||||
pub type CFNumberRef = *const c_void;
|
||||
|
||||
pub type mach_port_t = c_uint;
|
||||
pub type kern_return_t = c_int;
|
||||
pub type io_object_t = mach_port_t;
|
||||
pub type io_iterator_t = io_object_t;
|
||||
pub type io_registry_entry_t = io_object_t;
|
||||
|
||||
pub type IOOptionBits = u32;
|
||||
pub type CFNumberType = u32;
|
||||
|
||||
pub const kIOMainPortDefault: mach_port_t = 0;
|
||||
pub const kIOServicePlane: &str = "IOService\0";
|
||||
pub const kCFNumberSInt64Type: CFNumberType = 4;
|
||||
|
||||
pub const MACH_PORT_NULL: i32 = 0;
|
||||
|
||||
#[link(name = "IOKit", kind = "framework")]
|
||||
extern "C" {
|
||||
pub fn IOServiceGetMatchingServices(
|
||||
mainPort: mach_port_t,
|
||||
matching: CFMutableDictionaryRef,
|
||||
existing: *mut io_iterator_t,
|
||||
) -> kern_return_t;
|
||||
|
||||
pub fn IOServiceMatching(a: *const c_char) -> CFMutableDictionaryRef;
|
||||
|
||||
pub fn IOIteratorNext(iterator: io_iterator_t) -> io_object_t;
|
||||
|
||||
pub fn IOObjectRelease(obj: io_object_t) -> kern_return_t;
|
||||
|
||||
pub fn IORegistryEntrySearchCFProperty(
|
||||
entry: io_registry_entry_t,
|
||||
plane: *const c_char,
|
||||
key: CFStringRef,
|
||||
allocator: CFAllocatorRef,
|
||||
options: IOOptionBits,
|
||||
) -> CFTypeRef;
|
||||
}
|
||||
|
||||
#[link(name = "CoreFoundation", kind = "framework")]
|
||||
extern "C" {
|
||||
pub fn CFNumberGetValue(
|
||||
number: CFNumberRef,
|
||||
theType: CFNumberType,
|
||||
valuePtr: *mut c_void,
|
||||
) -> bool;
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
fn __CFStringMakeConstantString(c_str: *const c_char) -> CFStringRef;
|
||||
}
|
||||
|
||||
pub fn cfstr(val: &str) -> CFStringRef {
|
||||
unsafe { __CFStringMakeConstantString(val.as_ptr().cast()) }
|
||||
}
|
122
candle-metal-kernels/src/gpu.rs
Normal file
122
candle-metal-kernels/src/gpu.rs
Normal file
@ -0,0 +1,122 @@
|
||||
use core::ffi::c_void;
|
||||
use metal::Device;
|
||||
|
||||
use crate::ffi::*;
|
||||
|
||||
const GPU_CORE_COUNT_KEY: &str = "gpu-core-count\0";
|
||||
const AGXACCELERATOR_KEY: &str = "AGXAccelerator\0";
|
||||
|
||||
struct IOIterator(io_iterator_t);
|
||||
|
||||
impl IOIterator {
|
||||
fn new(it: io_iterator_t) -> Self {
|
||||
IOIterator(it)
|
||||
}
|
||||
|
||||
fn next(&self) -> Option<io_object_t> {
|
||||
let result = unsafe { IOIteratorNext(self.0) };
|
||||
if result == MACH_PORT_NULL as u32 {
|
||||
return None;
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for IOIterator {
|
||||
fn drop(&mut self) {
|
||||
unsafe { IOObjectRelease(self.0 as _) };
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn get_io_service_matching(val: &str) -> Result<CFMutableDictionaryRef, String> {
|
||||
let matching = IOServiceMatching(val.as_ptr().cast());
|
||||
if matching.is_null() {
|
||||
return Err(format!("IOServiceMatching call failed, `{val}` not found"));
|
||||
}
|
||||
Ok(matching)
|
||||
}
|
||||
|
||||
unsafe fn get_matching_services(
|
||||
main_port: mach_port_t,
|
||||
matching: CFMutableDictionaryRef,
|
||||
) -> Result<IOIterator, String> {
|
||||
let mut iterator: io_iterator_t = 0;
|
||||
let result = IOServiceGetMatchingServices(main_port, matching, &mut iterator);
|
||||
if result != 0 {
|
||||
return Err("Error getting matching services".to_string());
|
||||
}
|
||||
Ok(IOIterator::new(iterator))
|
||||
}
|
||||
|
||||
unsafe fn get_gpu_io_service() -> Result<io_object_t, String> {
|
||||
let matching = get_io_service_matching(AGXACCELERATOR_KEY)?;
|
||||
let iterator = get_matching_services(kIOMainPortDefault, matching)?;
|
||||
iterator
|
||||
.next()
|
||||
.ok_or("Error getting GPU IO Service".to_string())
|
||||
}
|
||||
|
||||
unsafe fn get_property_by_key(
|
||||
entry: io_registry_entry_t,
|
||||
plane: &str,
|
||||
key: &str,
|
||||
allocator: CFAllocatorRef,
|
||||
options: IOOptionBits,
|
||||
) -> Result<CFTypeRef, String> {
|
||||
let result = IORegistryEntrySearchCFProperty(
|
||||
entry,
|
||||
plane.as_ptr().cast(),
|
||||
cfstr(key),
|
||||
allocator,
|
||||
options,
|
||||
);
|
||||
if result.is_null() {
|
||||
return Err(format!("Error getting {key} property"));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
unsafe fn get_int_value(number: CFNumberRef) -> Result<i64, String> {
|
||||
let mut value: i64 = 0;
|
||||
let result = CFNumberGetValue(
|
||||
number,
|
||||
kCFNumberSInt64Type,
|
||||
&mut value as *mut i64 as *mut c_void,
|
||||
);
|
||||
if !result {
|
||||
return Err("Error getting int value".to_string());
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
unsafe fn find_core_count() -> Result<usize, String> {
|
||||
let gpu_io_service = get_gpu_io_service()?;
|
||||
let gpu_core_count = get_property_by_key(
|
||||
gpu_io_service,
|
||||
kIOServicePlane,
|
||||
GPU_CORE_COUNT_KEY,
|
||||
core::ptr::null(),
|
||||
0,
|
||||
)?;
|
||||
let value = get_int_value(gpu_core_count as CFNumberRef)?;
|
||||
Ok(value as usize)
|
||||
}
|
||||
|
||||
pub(crate) fn get_device_core_count(device: &Device) -> usize {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
unsafe { find_core_count().expect("Retrieving gpu core count failed") }
|
||||
}
|
||||
#[cfg(target_os = "ios")]
|
||||
{
|
||||
if device.name().starts_with("A") {
|
||||
if device.supports_family(MTLGPUFamily::Apple9) {
|
||||
6
|
||||
} else {
|
||||
5
|
||||
}
|
||||
} else {
|
||||
10
|
||||
}
|
||||
}
|
||||
}
|
@ -68,6 +68,50 @@ METAL_FUNC void im2col(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void col2im1d(
|
||||
constant size_t &dst_el,
|
||||
constant size_t &l_out,
|
||||
constant size_t &l_in,
|
||||
constant size_t &c_out,
|
||||
constant size_t &k_size,
|
||||
constant size_t &stride,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint dst_i [[ thread_position_in_grid ]]
|
||||
) {
|
||||
// src: (b_size, l_in, c_out, l_k)
|
||||
// dst: (b_size, c_out, l_out)
|
||||
if (dst_i >= dst_el) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t dst_s0 = c_out * l_out;
|
||||
const size_t dst_s1 = l_out;
|
||||
const size_t src_s0 = c_out * k_size * l_in;
|
||||
const size_t src_s1 = c_out * k_size;
|
||||
const size_t src_s2 = k_size;
|
||||
|
||||
size_t tmp_dst_i = dst_i;
|
||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||
tmp_dst_i -= b_idx * dst_s0;
|
||||
const size_t c_idx = tmp_dst_i / dst_s1;
|
||||
tmp_dst_i -= c_idx * dst_s1;
|
||||
const int l_out_idx = tmp_dst_i;
|
||||
|
||||
dst[dst_i] = static_cast<T>(0);
|
||||
|
||||
int l_in_idx = l_out_idx / stride;
|
||||
int k0 = l_out_idx - l_in_idx * stride;
|
||||
// l_out_idx = l_in_idx * stride + k0
|
||||
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
|
||||
if (l_in_idx < l_in) {
|
||||
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
|
||||
dst[dst_i] += src[src_i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void im2col1d(
|
||||
constant size_t &dst_numel,
|
||||
@ -190,6 +234,21 @@ kernel void FN_NAME( \
|
||||
) { \
|
||||
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define COL2IM1D_OP(T, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dst_el, \
|
||||
constant size_t &l_out, \
|
||||
constant size_t &l_in, \
|
||||
constant size_t &c_out, \
|
||||
constant size_t &k_size, \
|
||||
constant size_t &stride, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
|
||||
} \
|
||||
|
||||
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
|
||||
IM2COL_OP(bfloat, im2col_bf16)
|
||||
#endif
|
||||
|
||||
COL2IM1D_OP(float, col2im1d_f32)
|
||||
COL2IM1D_OP(uint8_t, col2im1d_u8)
|
||||
COL2IM1D_OP(uint32_t, col2im1d_u32)
|
||||
|
||||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32)
|
||||
CONVT2D_OP(half, float, conv_transpose2d_f16)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
|
||||
#endif
|
||||
#endif
|
226
candle-metal-kernels/src/kernels/event.metal
Normal file
226
candle-metal-kernels/src/kernels/event.metal
Normal file
@ -0,0 +1,226 @@
|
||||
// -*- Metal -*-
|
||||
//===-- metal_simdgroup_event ---------------------------------------------===//
|
||||
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_EVENT
|
||||
#define __METAL_SIMDGROUP_EVENT
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %struct._simdgroup_event_t = type opaque
|
||||
//
|
||||
struct _simdgroup_event_t;
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// Bitcode: TBD
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, threadgroup void *, const device void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// Bitcode: TBD
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_1d(
|
||||
ulong, ulong, device void *, const threadgroup void *, ulong)
|
||||
__asm("air.simdgroup_async_copy_1d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p3i8.p1i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(3)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(1)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
threadgroup void *, ulong, ulong, ulong2,
|
||||
const device void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p3i8.p1i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: argmemonly convergent nounwind
|
||||
// declare %struct._simdgroup_event_t*
|
||||
// @air.simdgroup_async_copy_2d.p1i8.p3i8(
|
||||
// i64, i64,
|
||||
// i8 addrspace(1)* nocapture writeonly, i64, i64, <2 x i64>,
|
||||
// i8 addrspace(3)* nocapture readonly, i64, i64, <2 x i64>,
|
||||
// <2 x i64>, i32)
|
||||
// local_unnamed_addr #4
|
||||
//
|
||||
thread _simdgroup_event_t*
|
||||
__metal_simdgroup_async_copy_2d(
|
||||
ulong, ulong,
|
||||
device void *, ulong, ulong, ulong2,
|
||||
const threadgroup void *, ulong, ulong, ulong2,
|
||||
long2, int)
|
||||
__asm("air.simdgroup_async_copy_2d.p1i8.p3i8");
|
||||
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// ; Function Attrs: convergent nounwind
|
||||
// declare void
|
||||
// @air.wait_simdgroup_events(i32, %struct._simdgroup_event_t** nocapture)
|
||||
// local_unnamed_addr #3
|
||||
//
|
||||
void __metal_wait_simdgroup_events(
|
||||
int, thread _simdgroup_event_t**)
|
||||
__asm("air.wait_simdgroup_events");
|
||||
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
enum class simdgroup_async_copy_clamp_mode {
|
||||
clamp_to_zero = 0,
|
||||
clamp_to_edge = 1
|
||||
};
|
||||
|
||||
struct simdgroup_event {
|
||||
METAL_FUNC simdgroup_event() thread {}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
threadgroup T *dst,
|
||||
const device T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
reinterpret_cast<const device void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
device T *dst,
|
||||
const threadgroup T *src,
|
||||
ulong n_elements
|
||||
) thread {
|
||||
event = __metal_simdgroup_async_copy_1d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the arguments.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
n_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
threadgroup T *dst,
|
||||
ushort dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const device T *src,
|
||||
uint src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false,
|
||||
simdgroup_async_copy_clamp_mode clamp_mode =
|
||||
simdgroup_async_copy_clamp_mode::clamp_to_zero
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<threadgroup void *>(dst),
|
||||
ushort(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const device void *>(src),
|
||||
uint(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
static_cast<int>(clamp_mode));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void async_copy(
|
||||
// Description of the destination.
|
||||
device T *dst,
|
||||
uint dst_elements_per_row,
|
||||
ushort2 dst_tile_dimensions,
|
||||
|
||||
// Description of the source.
|
||||
const threadgroup T *src,
|
||||
ushort src_elements_per_row,
|
||||
ushort2 src_tile_dimensions,
|
||||
|
||||
// Other arguments.
|
||||
bool transpose_matrix = false
|
||||
) thread {
|
||||
if (transpose_matrix) {
|
||||
src_tile_dimensions = src_tile_dimensions.yx;
|
||||
dst_tile_dimensions = dst_tile_dimensions.yx;
|
||||
}
|
||||
event = __metal_simdgroup_async_copy_2d(
|
||||
// Description of the data type.
|
||||
sizeof(T),
|
||||
alignof(T),
|
||||
|
||||
// Description of the destination.
|
||||
reinterpret_cast<device void *>(dst),
|
||||
uint(dst_elements_per_row),
|
||||
1,
|
||||
ulong2(dst_tile_dimensions),
|
||||
|
||||
// Description of the source.
|
||||
reinterpret_cast<const threadgroup void *>(src),
|
||||
ushort(src_elements_per_row),
|
||||
1,
|
||||
ulong2(src_tile_dimensions),
|
||||
|
||||
// Other arguments.
|
||||
long2(0),
|
||||
0);
|
||||
}
|
||||
|
||||
METAL_FUNC static void wait(int count, thread simdgroup_event *events) {
|
||||
__metal_wait_simdgroup_events(
|
||||
count, reinterpret_cast<thread _simdgroup_event_t**>(events));
|
||||
}
|
||||
|
||||
private:
|
||||
// Invoking the generation of LLVM bitcode for async copies.
|
||||
//
|
||||
// %"struct.metal::simdgroup_event" = type { %struct._simdgroup_event_t* }
|
||||
//
|
||||
thread _simdgroup_event_t* event;
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif
|
538
candle-metal-kernels/src/kernels/gemm.metal
Normal file
538
candle-metal-kernels/src/kernels/gemm.metal
Normal file
@ -0,0 +1,538 @@
|
||||
// Heavily inspired by the GEMM kernels by Philip Turner:
|
||||
// https://github.com/philipturner/metal-flash-attention
|
||||
// This implementation uses generics and specialization to generate kernels for different data types instead of code generation.
|
||||
#include <metal_stdlib>
|
||||
#include "event.metal"
|
||||
#include "matrix_storage.metal"
|
||||
using namespace metal;
|
||||
|
||||
// Dimensions of each matrix.
|
||||
// - Limitations to matrix size:
|
||||
// - 2^32 in each dimension (M/N/K).
|
||||
// - TODO: Test whether the maximum dimension with correct execution is
|
||||
// actually 2^16. This will require a testing setup with non-square
|
||||
// matrices, as 65536^3 is uncomputable.
|
||||
// - Extending to 2^64 may require changing 'uint' to 'ulong'. There is a
|
||||
// good chance this will significantly degrade performance, and require
|
||||
// changing the data type of several variables that process addresses. The
|
||||
// client is responsible for ensuring correctness and performance with
|
||||
// matrices spanning several billion elements in one direction.
|
||||
// - The matrix dimensions must be known at compile time, via function
|
||||
// constants. Dynamic matrix shapes are beyond the scope of this reference
|
||||
// implementation. Dynamic shapes cause a non-negligible regression to
|
||||
// shader execution speed. However, they could minimize a compilation
|
||||
// latency bottleneck in some use cases.
|
||||
// - Limitations to batch size:
|
||||
// - Dictated by how the client modifies the code to implement batching.
|
||||
// - Dynamic batch shapes would likely not harm performance much. For example,
|
||||
// someone could enter an array of pointers/memory offsets to different
|
||||
// matrices in the batch. Each slice of a 3D thread grid could read a
|
||||
// different pointer from memory, and use that pointer as the A/B/C matrix.
|
||||
// Another approach is to restrict the input format, so all matrices are
|
||||
// stored contiguously in memory. Then, the memory offset could be computed
|
||||
// analytically from matrix size and the Z dimension in a 3D thread grid.
|
||||
//
|
||||
// Another note:
|
||||
// - The rows of the matrix must be contiguous in memory. Supporting strides
|
||||
// that differ from the actual matrix dimensions should not be difficult, but
|
||||
// it is out of scope for this reference kernel.
|
||||
constant uint M [[function_constant(0)]];
|
||||
constant uint N [[function_constant(1)]];
|
||||
constant uint K [[function_constant(2)]];
|
||||
|
||||
// Whether each matrix is transposed.
|
||||
constant bool A_trans [[function_constant(10)]];
|
||||
constant bool B_trans [[function_constant(11)]];
|
||||
|
||||
constant bool prefer_async_copy [[function_constant(206)]];
|
||||
constant bool ideal_grouping [[function_constant(207)]];
|
||||
|
||||
constant bool batched [[function_constant(100)]];
|
||||
|
||||
constant ushort A_leading_dim = A_trans ? M : K;
|
||||
constant ushort B_leading_dim = B_trans ? K : N;
|
||||
|
||||
// The layout of threads within a SIMD matrix.
|
||||
//
|
||||
// 0 0 1 1 8 8 9 9
|
||||
// 2 2 3 3 10 10 11 11
|
||||
// 4 4 5 5 12 12 13 13
|
||||
// 6 6 7 7 14 14 15 15
|
||||
// 16 16 17 17 24 24 25 25
|
||||
// 18 18 19 19 26 26 27 27
|
||||
// 20 20 21 21 28 28 29 29
|
||||
// 22 22 23 23 30 30 31 31
|
||||
//
|
||||
// This is Morton order, a method for coalescing data accesses. It is used
|
||||
// in a variety of contexts, from ray tracing acceleration structures, to
|
||||
// nodal-point Laplacians, to sorting large lattices of atoms.
|
||||
//
|
||||
// Source: https://patents.google.com/patent/US11256518B2
|
||||
METAL_FUNC ushort2 morton_order(ushort thread_index_in_simdgroup) {
|
||||
ushort lane_id = thread_index_in_simdgroup;
|
||||
ushort quad_id = lane_id / 4;
|
||||
|
||||
constexpr ushort QUADRANT_SPAN_M = 4;
|
||||
constexpr ushort THREADS_PER_QUADRANT = 8;
|
||||
ushort M_floor_of_quadrant = (quad_id / 4) * QUADRANT_SPAN_M;
|
||||
ushort M_in_quadrant = (lane_id / 2) % (THREADS_PER_QUADRANT / 2);
|
||||
ushort M_in_simd = M_floor_of_quadrant + M_in_quadrant;
|
||||
|
||||
ushort N_floor_of_quadrant = (quad_id & 2) * 2; // 0 or 4
|
||||
ushort N_in_quadrant = (lane_id % 2) * 2; // 0 or 2
|
||||
ushort N_in_simd = N_floor_of_quadrant + N_in_quadrant;
|
||||
|
||||
return ushort2(N_in_simd, M_in_simd);
|
||||
}
|
||||
|
||||
// Indexes into an array of registers.
|
||||
//
|
||||
// Calls to this function are expected to be evaluated at compile time. The
|
||||
// array indices transform into register offsets, which are embedded into the
|
||||
// assembly code.
|
||||
template <typename T>
|
||||
METAL_FUNC thread simdgroup_matrix_storage<T>* get_sram(
|
||||
thread simdgroup_matrix_storage<T> *sram,
|
||||
ushort sram_leading_dim,
|
||||
ushort2 matrix_origin
|
||||
) {
|
||||
return sram + (matrix_origin.y / 8) * (sram_leading_dim / 8) + (matrix_origin.x / 8);
|
||||
}
|
||||
|
||||
// One multiply-accumulate loop iteration, or 8 dot products.
|
||||
template<
|
||||
typename T,
|
||||
typename U = T,
|
||||
ushort M_register,
|
||||
ushort N_register
|
||||
>
|
||||
METAL_FUNC void multiply_accumulate(
|
||||
const device T *A_src,
|
||||
const device U *B_src,
|
||||
thread simdgroup_matrix_storage<T> *A_sram,
|
||||
thread simdgroup_matrix_storage<U> *B_sram,
|
||||
thread simdgroup_matrix_storage<U> *C_sram,
|
||||
ushort k
|
||||
) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
auto A = get_sram(A_sram, 8, origin);
|
||||
A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
auto B = get_sram(B_sram, N_register, origin);
|
||||
B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
auto A = get_sram(A_sram, 8, ushort2(0, m));
|
||||
auto B = get_sram(B_sram, N_register, ushort2(n, 0));
|
||||
auto C = get_sram(C_sram, N_register, ushort2(n, m));
|
||||
C->multiply(*A, *B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// One multiply-accumulate loop iteration, or 8 dot products.
|
||||
template<
|
||||
typename T,
|
||||
typename U = T,
|
||||
ushort M_register,
|
||||
ushort N_register
|
||||
>
|
||||
METAL_FUNC void multiply_accumulate(
|
||||
const threadgroup T *A_src,
|
||||
const threadgroup U *B_src,
|
||||
thread simdgroup_matrix_storage<T> *A_sram,
|
||||
thread simdgroup_matrix_storage<U> *B_sram,
|
||||
thread simdgroup_matrix_storage<U> *C_sram,
|
||||
ushort k
|
||||
) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
ushort2 origin(0, m);
|
||||
auto A = get_sram(A_sram, 8, origin);
|
||||
A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, 0);
|
||||
auto B = get_sram(B_sram, N_register, origin);
|
||||
B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
auto A = get_sram(A_sram, 8, ushort2(0, m));
|
||||
auto B = get_sram(B_sram, N_register, ushort2(n, 0));
|
||||
auto C = get_sram(C_sram, N_register, ushort2(n, m));
|
||||
C->multiply(*A, *B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Metal function arguments.
|
||||
//
|
||||
// A: the left-hand side matrix
|
||||
// - dimensions: M x K
|
||||
// K x M (transposed)
|
||||
// - memory precision: T
|
||||
// - register precision: T
|
||||
//
|
||||
// B: the right-hand side matrix
|
||||
// - dimensions: K x N
|
||||
// N x K (transposed)
|
||||
// - memory precision: U
|
||||
// - register precision: U
|
||||
//
|
||||
// C: the output matrix, alternatively the dot product accumulator
|
||||
// - dimensions: M x N
|
||||
// - memory precision: V
|
||||
// - register precision: V
|
||||
//
|
||||
// threadgroup_block: the chunk of threadgroup memory allocated at runtime
|
||||
// - ideally 10 KB or less
|
||||
// - precision: void/8-bit integer to make the pointer arithmetic more legible
|
||||
template <
|
||||
typename T,
|
||||
typename U = T,
|
||||
typename V = U,
|
||||
ushort M_group,
|
||||
ushort N_group,
|
||||
ushort K_group,
|
||||
ushort M_splits,
|
||||
ushort N_splits,
|
||||
ushort M_register = M_group / M_splits,
|
||||
ushort N_register = N_group / N_splits
|
||||
>
|
||||
void gemm_impl(
|
||||
device T *A [[buffer(0)]],
|
||||
device U *B [[buffer(1)]],
|
||||
device V *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
const ushort A_leading_block_dim = A_trans ? M_group : K_group;
|
||||
const ushort B_leading_block_dim = B_trans ? K_group : N_group;
|
||||
|
||||
// Thresholds that mark the matrix edge.
|
||||
const uint M_edge = M - (M % M_group);
|
||||
const uint N_edge = N - (N % N_group);
|
||||
|
||||
const ushort async_iter_start = prefer_async_copy ? 0 : (K - (K % K_group));
|
||||
|
||||
// Find the number of elements in the final block. If the matrix
|
||||
// dimensions are perfectly divisibly by block dimensions, we don't want
|
||||
// this value to be zero. The final block is a full block.
|
||||
const uint M_remainder = (M % M_register == 0)
|
||||
? M_register : M % M_register;
|
||||
const ushort N_remainder = (N % N_register == 0)
|
||||
? N_register : N % N_register;
|
||||
const ushort K_remainder = (K % K_group == 0)
|
||||
? K_group : K % K_group;
|
||||
const ushort K_remainder_padded = (K_remainder + 7) / 8 * 8;
|
||||
|
||||
// Shift the final block, so it doesn't access out-of-bounds memory.
|
||||
const ushort M_shift = (M < M_group) ? 0 : M_register - M_remainder;
|
||||
const ushort N_shift = (N < N_group) ? 0 : N_register - N_remainder;
|
||||
|
||||
if (batched) {
|
||||
ulong3 offsets = matrix_offsets[0].xyz * gid.z;
|
||||
A = (device T*)((device uchar*)A + offsets[0]);
|
||||
B = (device U*)((device uchar*)B + offsets[1]);
|
||||
C = (device V*)((device uchar*)C + offsets[2]);
|
||||
}
|
||||
|
||||
auto A_block = (threadgroup T*)(threadgroup_block);
|
||||
auto B_block = (threadgroup U*)(threadgroup_block + (M * K));
|
||||
ushort2 sid(sidx % N_splits, sidx / N_splits);
|
||||
ushort2 morton_offset = morton_order(lane_id);
|
||||
|
||||
// Return early if the SIMD is out of bounds.
|
||||
//
|
||||
// There could be some threadgroups where the matrix edge cuts straight
|
||||
// through the middle of the block. SIMDs on the right or bottom of the
|
||||
// dividing line must be stopped from causing out-of-bounds accesses. This is
|
||||
// the reason for the early exit.
|
||||
uint M_offset = gid.y * M_group;
|
||||
uint N_offset = gid.x * N_group;
|
||||
if (M_offset + sid.y * M_register >= M ||
|
||||
N_offset + sid.x * N_register >= N) {
|
||||
return;
|
||||
}
|
||||
ushort2 offset_in_group(sid.x * N_register + morton_offset.x,
|
||||
sid.y * M_register + morton_offset.y);
|
||||
|
||||
// Shift the matrix block within bounds, if possible.
|
||||
if ((M_shift != 0) && (gid.y * M_group >= M_edge)) {
|
||||
M_offset -= M_shift;
|
||||
}
|
||||
if ((N_shift != 0) && (gid.x * N_group >= N_edge)) {
|
||||
N_offset -= N_shift;
|
||||
}
|
||||
|
||||
simdgroup_matrix_storage<V> C_sram[(M_register / 8) * (N_register / 8)];
|
||||
|
||||
// Initialize the accumulator.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(m, n);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
*C = simdgroup_matrix_storage<V>(0);
|
||||
}
|
||||
}
|
||||
// Perform the iterations where async copy is avoided.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = 0; k < async_iter_start; k += 8) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
A_offset += uint2(morton_offset.x, offset_in_group.y);
|
||||
B_offset += uint2(offset_in_group.x, morton_offset.y);
|
||||
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[M_register / 8];
|
||||
simdgroup_matrix_storage<U> B_sram[N_register / 8];
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
|
||||
}
|
||||
if (!prefer_async_copy) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = 0; k < K; k += K_group) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
A_offset += uint2(morton_offset.x, offset_in_group.y);
|
||||
B_offset += uint2(offset_in_group.x, morton_offset.y);
|
||||
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[M_register / 8];
|
||||
simdgroup_matrix_storage<U> B_sram[N_register / 8];
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_src, B_src, A_sram, B_sram, C_sram, 0);
|
||||
}
|
||||
} else {
|
||||
// Perform the iterations where async copy is used.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint k = async_iter_start; k < K; k += K_group) {
|
||||
// Launch an async copy from device to threadgroup memory.
|
||||
if (sidx == 0) {
|
||||
uint2 A_offset(k, M_offset);
|
||||
uint2 B_offset(N_offset, k);
|
||||
auto A_src = simdgroup_matrix_storage<T>::apply_offset(A, A_leading_dim, A_offset, A_trans);
|
||||
auto B_src = simdgroup_matrix_storage<U>::apply_offset(B, B_leading_dim, B_offset, B_trans);
|
||||
|
||||
ushort M_tile_dimension = min(uint(M_group), M - M_offset);
|
||||
ushort N_tile_dimension = min(uint(N_group), N - N_offset);
|
||||
ushort K_tile_dimension = min(uint(K_group), K - k);
|
||||
ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
|
||||
|
||||
ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
|
||||
ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
|
||||
ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
|
||||
ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
|
||||
|
||||
simdgroup_event events[2];
|
||||
events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
|
||||
events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
|
||||
simdgroup_event::wait(2, events);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
|
||||
ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
|
||||
auto A_block_src = simdgroup_matrix_storage<T>::apply_offset(A_block, A_leading_block_dim, A_block_offset, A_trans);
|
||||
auto B_block_src = simdgroup_matrix_storage<U>::apply_offset(B_block, B_leading_block_dim, B_block_offset, B_trans);
|
||||
|
||||
simdgroup_matrix_storage<T> A_sram[(M_register / 8) * (K_group / 8)];
|
||||
simdgroup_matrix_storage<U> B_sram[(K_group / 8) * (N_register / 8)];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = 0; k < K_remainder_padded; k += 8) {
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
|
||||
}
|
||||
|
||||
// Will there be any iterations after this one?
|
||||
if (k + K_group < K) {
|
||||
// If so, we haven't reached the edge of either input matrix yet.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort k = K_remainder_padded; k < K_group; k += 8) {
|
||||
multiply_accumulate<T, U, M_register, N_register>(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!prefer_async_copy && (M >= M_group) && (N >= N_group)) {
|
||||
// Fast path for matrices that qualify.
|
||||
uint2 C_offset(N_offset + offset_in_group.x,
|
||||
M_offset + offset_in_group.y);
|
||||
auto C_dst = simdgroup_matrix_storage<U>::apply_offset(
|
||||
C, N, C_offset);
|
||||
|
||||
// Write the accumulator to device memory.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
C->store(C_dst, N, origin);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Slow path for when memory must be handled more carefully.
|
||||
auto C_block = (threadgroup V*)(threadgroup_block);
|
||||
auto C_block_dst = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, offset_in_group);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Write the accumulator to threadgroup memory.
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort m = 0; m < M_register; m += 8) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (ushort n = 0; n < N_register; n += 8) {
|
||||
ushort2 origin(n, m);
|
||||
auto C = get_sram(C_sram, N_register, origin);
|
||||
C->store(C_block_dst, N_group, origin);
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Launch the async copy from threadgroup to device memory.
|
||||
if (sidx == 0) {
|
||||
uint2 C_offset(gid.x * N_group, gid.y * M_group);
|
||||
ushort2 C_tile(min(uint(N_group), N - C_offset.x),
|
||||
min(uint(M_group), M - C_offset.y));
|
||||
auto C_dst = simdgroup_matrix_storage<V>::apply_offset(C, N, C_offset);
|
||||
|
||||
// If we shift successfully, the garbage zone moves from the bottom right
|
||||
// to the top left.
|
||||
if ((M_shift != 0) || (N_shift != 0)) {
|
||||
ushort2 C_block_shift(0, 0);
|
||||
if ((M_shift != 0) && (C_offset.y >= M_edge)) {
|
||||
C_block_shift.y = M_shift;
|
||||
}
|
||||
if ((N_shift != 0) && (C_offset.x >= N_edge)) {
|
||||
C_block_shift.x = N_shift;
|
||||
}
|
||||
C_block = simdgroup_matrix_storage<V>::apply_offset(C_block, N_group, C_block_shift);
|
||||
}
|
||||
|
||||
simdgroup_event event;
|
||||
event.async_copy(C_dst, N, C_tile, C_block, N_group, C_tile);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void hgemm(
|
||||
device half *A [[buffer(0)]],
|
||||
device half *B [[buffer(1)]],
|
||||
device half *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<half, half, half, 32, 32, 32, 1, 1>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<half, half, half, 48, 48, 32, 1, 1>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void sgemm(
|
||||
device float *A [[buffer(0)]],
|
||||
device float *B [[buffer(1)]],
|
||||
device float *C [[buffer(2)]],
|
||||
|
||||
threadgroup uchar *threadgroup_block [[threadgroup(0)]],
|
||||
constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
|
||||
|
||||
uint3 gid [[threadgroup_position_in_grid]],
|
||||
ushort sidx [[simdgroup_index_in_threadgroup]],
|
||||
ushort lane_id [[thread_index_in_simdgroup]]
|
||||
) {
|
||||
gemm_impl<float, float, float, 32, 32, 32, 2, 2>(
|
||||
A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
|
||||
);
|
||||
/*
|
||||
if (prefer_async_copy) {
|
||||
constexpr ushort M_split = 1;
|
||||
constexpr ushort N_split = 1;
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
32,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
48,
|
||||
48,
|
||||
24,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
} else {
|
||||
constexpr ushort M_split = 2;
|
||||
constexpr ushort N_split = 2;
|
||||
if (ideal_grouping) {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
8,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
} else {
|
||||
gemm_impl<
|
||||
float,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
100,
|
||||
M_split,
|
||||
N_split
|
||||
>(
|
||||
A, B, C, threadgroup_block, gid, sidx, lane_id
|
||||
);
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
243
candle-metal-kernels/src/kernels/matrix_storage.metal
Normal file
243
candle-metal-kernels/src/kernels/matrix_storage.metal
Normal file
@ -0,0 +1,243 @@
|
||||
// -*- Metal -*-
|
||||
//===-- metal_simdgroup_matrix_storage ------------------------------------===//
|
||||
// Copyright (c) 2024 Philip Turner. See MIT LICENSE
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
#define __METAL_SIMDGROUP_MATRIX_STORAGE
|
||||
|
||||
#pragma METAL internals : enable
|
||||
namespace metal
|
||||
{
|
||||
template <typename T>
|
||||
struct simdgroup_matrix_storage {
|
||||
typedef vec<T, 64> storage_type;
|
||||
|
||||
storage_type t;
|
||||
|
||||
METAL_FUNC thread vec<T, 2>* thread_elements() thread {
|
||||
return reinterpret_cast<thread vec<T, 2>*>(&t);
|
||||
}
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage() thread = default;
|
||||
|
||||
METAL_FUNC simdgroup_matrix_storage(vec<T, 2> thread_elements) thread {
|
||||
*(this->thread_elements()) = thread_elements;
|
||||
}
|
||||
|
||||
METAL_FUNC static device T* apply_offset(device T *src, uint elements_per_row, uint2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + ulong(matrix_origin.x * elements_per_row) + matrix_origin.y;
|
||||
} else {
|
||||
return src + ulong(matrix_origin.y * elements_per_row) + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC static threadgroup T* apply_offset(threadgroup T *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
return src + matrix_origin.x * elements_per_row + matrix_origin.y;
|
||||
} else {
|
||||
return src + matrix_origin.y * elements_per_row + matrix_origin.x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void load(const device U *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
vec<U, 2> memoryForm = *(const device vec<U, 2>*)(src + combinedAddress);
|
||||
*(thread_elements()) = vec<T, 2>(memoryForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void load_bfloat(const device bfloat *src, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
bfloat memoryForm0 = src[address0];
|
||||
bfloat memoryForm1 = src[address1];
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[1] = memoryForm0;
|
||||
registerForm[3] = memoryForm1;
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
bfloat2 memoryForm = *(const device packed_bfloat2*)(src + combinedAddress);
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm);
|
||||
((thread bfloat*)®isterForm)[1] = memoryForm[0];
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void load(const threadgroup U *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
U memoryForm0 = src[address0];
|
||||
U memoryForm1 = src[address1];
|
||||
((thread T*)thread_elements())[0] = T(memoryForm0);
|
||||
((thread T*)thread_elements())[1] = T(memoryForm1);
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
vec<U, 2> memoryForm = *(const threadgroup vec<U, 2>*)(src + combinedAddress);
|
||||
*(thread_elements()) = vec<T, 2>(memoryForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void load_bfloat(const threadgroup bfloat *src, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
bfloat memoryForm0 = src[address0];
|
||||
bfloat memoryForm1 = src[address1];
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[1] = memoryForm0;
|
||||
registerForm[3] = memoryForm1;
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
bfloat2 memoryForm = *(const threadgroup packed_bfloat2*)(src + combinedAddress);
|
||||
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
((thread float*)®isterForm)[1] = *(thread float*)(&memoryForm);
|
||||
((thread bfloat*)®isterForm)[1] = memoryForm[0];
|
||||
((thread bfloat4*)thread_elements())[0] = registerForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void store(device U *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
vec<T, 2> registerForm = *(thread_elements());
|
||||
*(device vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void store_bfloat(device bfloat *dst, uint elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
uint address0 = uint(matrix_origin.x + 0) * elements_per_row + uint(matrix_origin.y);
|
||||
uint address1 = uint(matrix_origin.x + 1) * elements_per_row + uint(matrix_origin.y);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
uint address0 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
uint address1 = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 1);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else {
|
||||
auto combinedAddress = uint(matrix_origin.y) * elements_per_row + uint(matrix_origin.x + 0);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
float memoryForm = ((thread float*)®isterForm)[1];
|
||||
*(device float*)(dst + combinedAddress) = memoryForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
METAL_FUNC void store(threadgroup U *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
T registerForm0 = ((thread T*)thread_elements())[0];
|
||||
T registerForm1 = ((thread T*)thread_elements())[1];
|
||||
dst[address0] = U(registerForm0);
|
||||
dst[address1] = U(registerForm1);
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
vec<T, 2> registerForm = *(thread_elements());
|
||||
*(threadgroup vec<U, 2>*)(dst + combinedAddress) = vec<U, 2>(registerForm);
|
||||
}
|
||||
}
|
||||
|
||||
// WARNING: 'T' must be 'float'.
|
||||
METAL_FUNC void store_bfloat(threadgroup bfloat *dst, ushort elements_per_row, ushort2 matrix_origin, bool transpose_matrix = false) {
|
||||
if (transpose_matrix) {
|
||||
ushort address0 = ushort(matrix_origin.x + 0) * elements_per_row + ushort(matrix_origin.y);
|
||||
ushort address1 = ushort(matrix_origin.x + 1) * elements_per_row + ushort(matrix_origin.y);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else if (elements_per_row % 2 != 0) {
|
||||
ushort address0 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
ushort address1 = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 1);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
dst[address0] = registerForm[2];
|
||||
dst[address1] = registerForm[3];
|
||||
} else {
|
||||
auto combinedAddress = ushort(matrix_origin.y) * elements_per_row + ushort(matrix_origin.x + 0);
|
||||
bfloat4 registerForm = *(thread bfloat4*)(thread_elements());
|
||||
registerForm[2] = registerForm[1];
|
||||
float memoryForm = ((thread float*)®isterForm)[1];
|
||||
*(threadgroup float*)(dst + combinedAddress) = memoryForm;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename U, typename V>
|
||||
METAL_FUNC void multiply(simdgroup_matrix_storage<U> a, simdgroup_matrix_storage<V> b, bool accumulate = true) {
|
||||
if (!accumulate) {
|
||||
*(thread_elements()) = vec<T, 2>(0);
|
||||
}
|
||||
t = __metal_simdgroup_matrix_8x8_multiply_accumulate(a.t, b.t, t, typename simdgroup_matrix_storage<T>::storage_type());
|
||||
}
|
||||
};
|
||||
} // namespace metal
|
||||
#pragma METAL internals : disable
|
||||
|
||||
#endif
|
@ -104,7 +104,7 @@ METAL_FUNC void argmax(
|
||||
threadgroup T * shared_memory,
|
||||
threadgroup uint * shared_indices
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
@ -173,7 +173,7 @@ METAL_FUNC void reduce(
|
||||
threadgroup T * shared_memory,
|
||||
T (*fn)(T, T)
|
||||
) {
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block
|
||||
// to (dst_id + 1) * el_to_sum_per_block.
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||
@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void layernorm(
|
||||
constant size_t & src_numel,
|
||||
constant size_t & el_to_sum_per_block,
|
||||
device const T * src,
|
||||
device T * dst,
|
||||
device const T * alpha,
|
||||
device const T * beta,
|
||||
constant float & eps,
|
||||
uint id,
|
||||
uint tid,
|
||||
uint dst_id,
|
||||
uint block_dim,
|
||||
threadgroup float * shared_memory
|
||||
) {
|
||||
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
|
||||
size_t idx = start_idx + tid;
|
||||
|
||||
float tmp1 = 0;
|
||||
float tmp2 = 0;
|
||||
while (idx < stop_idx) {
|
||||
tmp1 += float(src[idx]);
|
||||
tmp2 += float(src[idx]) * float(src[idx]);
|
||||
idx += block_dim;
|
||||
}
|
||||
shared_memory[tid] = tmp1;
|
||||
shared_memory[tid + block_dim] = tmp2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
|
||||
shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
/* wait for shared_memory[0] to be filled */
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float mean = shared_memory[0] / float(el_to_sum_per_block);
|
||||
float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
|
||||
float inv_norm = 1.0f / sqrt(var + eps);
|
||||
idx = start_idx + tid;
|
||||
while (idx < stop_idx) {
|
||||
float val = (float(src[idx]) - mean) * inv_norm;
|
||||
if (alpha != nullptr) {
|
||||
val *= float(alpha[idx - start_idx]);
|
||||
}
|
||||
if (beta != nullptr) {
|
||||
val += float(beta[idx - start_idx]);
|
||||
}
|
||||
dst[idx] = T(val);
|
||||
idx += block_dim;
|
||||
}
|
||||
}
|
||||
|
||||
#define RMSNORM(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
@ -371,6 +430,25 @@ kernel void NAME( \
|
||||
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
|
||||
} \
|
||||
|
||||
#define LAYERNORM(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
device const T *alpha, \
|
||||
device const T *beta, \
|
||||
constant float &eps, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = 0; \
|
||||
layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC void ropei(
|
||||
constant size_t &bh,
|
||||
@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
RMSNORM(rmsnorm_f32, float)
|
||||
RMSNORM(rmsnorm_f16, half)
|
||||
LAYERNORM(layernorm_f32, float)
|
||||
LAYERNORM(layernorm_f16, half)
|
||||
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
|
||||
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
|
||||
|
||||
@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
RMSNORM(rmsnorm_bf16, bfloat)
|
||||
LAYERNORM(layernorm_bf16, bfloat)
|
||||
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
|
||||
#endif
|
@ -1,5 +1,4 @@
|
||||
#include <metal_stdlib>
|
||||
#
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
@ -57,27 +56,31 @@ kernel void FN_NAME(
|
||||
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
|
||||
} \
|
||||
|
||||
// WHERE_OP(float, int64_t, where_i64_f32)
|
||||
// WHERE_OP(double, int64_t, where_i64_f64)
|
||||
// WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
// WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
// WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
//
|
||||
// WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
// WHERE_OP(double, uint32_t, where_u32_f64)
|
||||
// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
// WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
WHERE_OP(half, uint32_t, where_u32_f16)
|
||||
WHERE_OP(float, uint32_t, where_u32_f32)
|
||||
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
|
||||
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
|
||||
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
WHERE_OP(half, uint8_t, where_u8_f16)
|
||||
WHERE_OP(float, uint8_t, where_u8_f32)
|
||||
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
|
||||
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
WHERE_OP(int64_t, uint8_t, where_u8_i64)
|
||||
WHERE_OP(int64_t, uint32_t, where_u32_i64)
|
||||
|
||||
WHERE_OP(half, int64_t, where_i64_f16)
|
||||
WHERE_OP(float, int64_t, where_i64_f32)
|
||||
WHERE_OP(uint8_t, int64_t, where_i64_u8)
|
||||
WHERE_OP(uint32_t, int64_t, where_i64_u32)
|
||||
WHERE_OP(int64_t, int64_t, where_i64_i64)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, int64_t, where_i64_bf16)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
|
||||
#endif
|
||||
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
|
||||
#endif
|
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
47
candle-metal-kernels/src/kernels/utils.metal
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC uint nonzero(uint n) {
|
||||
return n == 0 ? 1 : n;
|
||||
}
|
||||
|
||||
template<uint N>
|
||||
constexpr uint nonzero() {
|
||||
return N == 0 ? 1 : N;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr ushort granularity() {
|
||||
return nonzero<vec_elements<T>::value>();
|
||||
}
|
||||
|
||||
METAL_FUNC uint next_p2(uint x) {
|
||||
return 1 << (32 - clz(x - 1));
|
||||
}
|
||||
|
||||
METAL_FUNC uint prev_p2(uint x) {
|
||||
return 1 << (31 - clz(x));
|
||||
}
|
||||
|
||||
constant uint MAX_SHARED_MEM = 32767;
|
||||
|
||||
template<typename T>
|
||||
METAL_FUNC uint max_shared_mem(uint n) {
|
||||
return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T)));
|
||||
}
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant const uint &num_dims,
|
||||
constant const size_t *dims,
|
||||
constant const size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
@ -1,30 +1,37 @@
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
|
||||
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
FunctionConstantValues, Library, MTLDataType, MTLGPUFamily, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
mod ffi;
|
||||
mod gpu;
|
||||
use gpu::get_device_core_count;
|
||||
|
||||
mod utils;
|
||||
pub use utils::BufferOffset;
|
||||
use utils::{get_block_dims, linear_split};
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
const AFFINE: &str = include_str!("kernels/affine.metal");
|
||||
const INDEXING: &str = include_str!("kernels/indexing.metal");
|
||||
const UNARY: &str = include_str!("kernels/unary.metal");
|
||||
const BINARY: &str = include_str!("kernels/binary.metal");
|
||||
const TERNARY: &str = include_str!("kernels/ternary.metal");
|
||||
const CAST: &str = include_str!("kernels/cast.metal");
|
||||
const CONV: &str = include_str!("kernels/conv.metal");
|
||||
const REDUCE: &str = include_str!("kernels/reduce.metal");
|
||||
const RANDOM: &str = include_str!("kernels/random.metal");
|
||||
const QUANTIZED: &str = include_str!("kernels/quantized.metal");
|
||||
const SORT: &str = include_str!("kernels/sort.metal");
|
||||
const MFA: &[u8] = include_bytes!("libraries/libMetalFlashAttention.metallib");
|
||||
const CANDLE: &[u8] = include_bytes!("libraries/candle.metallib");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Candle,
|
||||
Affine,
|
||||
Indexing,
|
||||
Unary,
|
||||
@ -200,7 +207,7 @@ impl Kernels {
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Sort => SORT,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
_ => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,14 +223,16 @@ impl Kernels {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
Source::Candle => device.new_library_with_data(CANDLE).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load candle: {e}"
|
||||
))
|
||||
})?,
|
||||
Source::Mfa => device.new_library_with_data(MFA).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?,
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
@ -739,6 +748,69 @@ pub fn call_rms_norm(
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_layer_norm(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
eps: f32,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
alpha: &Buffer,
|
||||
alpha_offset: usize,
|
||||
beta: &Buffer,
|
||||
beta_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
length,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output,
|
||||
(alpha, alpha_offset),
|
||||
(beta, beta_offset),
|
||||
eps
|
||||
)
|
||||
);
|
||||
|
||||
let out_length = length / elements_to_sum;
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
@ -1402,6 +1474,29 @@ pub fn call_gemm(
|
||||
rhs_buffer: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let prefer_async_copy = !device.supports_family(MTLGPUFamily::Apple9);
|
||||
|
||||
let mut actual_groups: usize = 1;
|
||||
actual_groups *= divide(m, 48) as usize;
|
||||
actual_groups *= divide(n, 48) as usize;
|
||||
actual_groups *= b;
|
||||
|
||||
let core_count = get_device_core_count(device);
|
||||
let ideal_grouping = if name == "sgemm" {
|
||||
actual_groups <= core_count * 6
|
||||
} else {
|
||||
actual_groups <= core_count * 9
|
||||
};
|
||||
|
||||
let mut blockdim = (32, 32, 32);
|
||||
if !ideal_grouping {
|
||||
if name == "sgemm" {
|
||||
blockdim = (48, 48, 24);
|
||||
} else {
|
||||
blockdim = (48, 48, 32);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(rhs_stride.len() >= 2);
|
||||
assert!(lhs_stride.len() >= 2);
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
@ -1438,50 +1533,45 @@ pub fn call_gemm(
|
||||
let alpha = 1.0f32;
|
||||
let beta = 0.0f32;
|
||||
let batched = b > 1;
|
||||
println!("batched: {batched}");
|
||||
let fused_activation = false;
|
||||
let fused_bias = false;
|
||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||
let m_simd = 8;
|
||||
let n_simd = 8;
|
||||
let k_simd = 64;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
} else {
|
||||
let m_simd = 40;
|
||||
let n_simd = 40;
|
||||
let k_simd = 32;
|
||||
let m_splits = 1;
|
||||
let n_splits = 1;
|
||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||
};
|
||||
|
||||
let constants = Some(ConstantValues::new(vec![
|
||||
(0, Value::USize(m)),
|
||||
(1, Value::USize(n)),
|
||||
(2, Value::USize(k)),
|
||||
(10, Value::Bool(a_trans)),
|
||||
(11, Value::Bool(b_trans)),
|
||||
(13, Value::Bool(d_trans)),
|
||||
(20, Value::F32(alpha)),
|
||||
(21, Value::F32(beta)),
|
||||
//(13, Value::Bool(d_trans)),
|
||||
//(20, Value::F32(alpha)),
|
||||
//(21, Value::F32(beta)),
|
||||
(100, Value::Bool(batched)),
|
||||
(101, Value::Bool(fused_activation)),
|
||||
//(101, Value::Bool(fused_activation)),
|
||||
// Garbage
|
||||
(102, Value::Bool(false)),
|
||||
(103, Value::Bool(false)),
|
||||
(113, Value::Bool(false)),
|
||||
(50_000, Value::Bool(false)),
|
||||
// End garbage
|
||||
(200, Value::U16(m_simd)),
|
||||
(201, Value::U16(n_simd)),
|
||||
(202, Value::U16(k_simd)),
|
||||
(210, Value::U16(m_splits)),
|
||||
(211, Value::U16(n_splits)),
|
||||
(50_001, Value::Bool(fused_bias)),
|
||||
//(200, Value::U16(blockdim.0)),
|
||||
//(201, Value::U16(blockdim.1)),
|
||||
//(202, Value::U16(blockdim.2)),
|
||||
(206, Value::Bool(prefer_async_copy)),
|
||||
(207, Value::Bool(ideal_grouping)),
|
||||
//(210, Value::U16(m_splits)),
|
||||
//(211, Value::U16(n_splits)),
|
||||
//(50_001, Value::Bool(fused_bias)),
|
||||
]));
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||
let m_group = m_simd * m_splits;
|
||||
let n_group = n_simd * n_splits;
|
||||
let pipeline = kernels.load_pipeline_with_constants(device, Source::Candle, name, constants)?;
|
||||
|
||||
let m_group: u16 = 32;
|
||||
let n_group: u16 = 32;
|
||||
let m_splits: u16 = 2;
|
||||
let n_splits: u16 = 2;
|
||||
let k_simd: u16 = 32;
|
||||
let m_simd = m_group / m_splits;
|
||||
let n_simd = n_group / n_splits;
|
||||
|
||||
let a_block_length = m_group * k_simd;
|
||||
let b_block_length = k_simd * n_group;
|
||||
@ -1491,6 +1581,7 @@ pub fn call_gemm(
|
||||
let c_block_length = m_group * n_group;
|
||||
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||
}
|
||||
/*
|
||||
if fused_bias {
|
||||
if d_trans {
|
||||
block_elements = std::cmp::max(block_elements, m_group);
|
||||
@ -1498,6 +1589,7 @@ pub fn call_gemm(
|
||||
block_elements = std::cmp::max(block_elements, n_group);
|
||||
}
|
||||
}
|
||||
*/
|
||||
let bytes = match name {
|
||||
"sgemm" => 4,
|
||||
"hgemm" => 2,
|
||||
@ -1511,7 +1603,7 @@ pub fn call_gemm(
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes.into());
|
||||
encoder.set_threadgroup_memory_length(0, block_bytes as NSUInteger);
|
||||
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
|
||||
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
|
||||
encoder.set_buffer(2, Some(output), 0);
|
||||
@ -1525,7 +1617,7 @@ pub fn call_gemm(
|
||||
// TODO byte_stride_d
|
||||
let byte_stride_d = 0;
|
||||
|
||||
let buffer: Vec<u64> = vec![
|
||||
let buffer: [u64; 4] = [
|
||||
byte_stride_a as _,
|
||||
byte_stride_b as _,
|
||||
byte_stride_c as _,
|
||||
@ -1588,6 +1680,39 @@ pub fn call_im2col1d_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_col2im1d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
|
||||
let l_in = shape[1];
|
||||
let c_out = shape[2];
|
||||
let l_out = (l_in - 1) * stride + k_size;
|
||||
let dst_el = shape[0] * c_out * l_out;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
|
||||
);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_im2col_strided(
|
||||
device: &Device,
|
||||
|
BIN
candle-metal-kernels/src/libraries/candle.metallib
Normal file
BIN
candle-metal-kernels/src/libraries/candle.metallib
Normal file
Binary file not shown.
@ -1023,6 +1023,27 @@ fn where_cond() {
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
#[test]
|
||||
fn where_cond_u32_f32() {
|
||||
let shape = vec![6];
|
||||
let cond = vec![0u32, 1, 0, 0, 1, 1];
|
||||
let cond_l = (vec![1], 0);
|
||||
let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let left_l = (vec![1], 0);
|
||||
let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
|
||||
let right_l = (vec![1], 0);
|
||||
let results = run_where_cond(
|
||||
&shape,
|
||||
&cond,
|
||||
cond_l,
|
||||
&left_true,
|
||||
left_l,
|
||||
&right_false,
|
||||
right_l,
|
||||
"where_u32_f32",
|
||||
);
|
||||
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
|
||||
}
|
||||
|
||||
fn run_gemm<T: Clone>(
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
@ -1079,6 +1100,11 @@ fn gemm() {
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
println!("lhs: {lhs:?}");
|
||||
println!("lhs_stride: {lhs_stride:?}");
|
||||
println!("rhs: {rhs:?}");
|
||||
println!("rhs_stride: {rhs_stride:?}");
|
||||
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
@ -1090,6 +1116,11 @@ fn gemm() {
|
||||
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
|
||||
let rhs_stride = vec![n * k, n, 1];
|
||||
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
|
||||
println!("lhs: {lhs:?}");
|
||||
println!("lhs_stride: {lhs_stride:?}");
|
||||
println!("rhs: {rhs:?}");
|
||||
println!("rhs_stride: {rhs_stride:?}");
|
||||
|
||||
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
|
||||
assert_eq!(
|
||||
approx(results, 4),
|
||||
|
@ -5,7 +5,7 @@ use criterion::{black_box, criterion_group, Criterion};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) {
|
||||
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input);
|
||||
let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input);
|
||||
}
|
||||
|
||||
const B: usize = 1;
|
||||
|
@ -1,30 +1,25 @@
|
||||
use candle::{DType, Device, Result, Shape, Tensor};
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
all_data: Tensor,
|
||||
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
|
||||
// on the first call where the batch size is easily known.
|
||||
// Also this makes it safe to clone a KvCache that has been reseted (as in it will not share
|
||||
// its internal state with the cloned instance).
|
||||
all_data: Option<Tensor>,
|
||||
dim: usize,
|
||||
current_seq_len: usize,
|
||||
max_seq_len: usize,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let max_seq_len = shape.dims()[dim];
|
||||
let all_data = Tensor::zeros(shape, dtype, dev)?;
|
||||
Ok(Self {
|
||||
all_data,
|
||||
pub fn new(dim: usize, max_seq_len: usize) -> Self {
|
||||
Self {
|
||||
all_data: None,
|
||||
dim,
|
||||
current_seq_len: 0,
|
||||
max_seq_len,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dim(&self) -> usize {
|
||||
@ -39,16 +34,34 @@ impl Cache {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
pub fn all_data(&self) -> &Tensor {
|
||||
pub fn all_data(&self) -> &Option<Tensor> {
|
||||
&self.all_data
|
||||
}
|
||||
|
||||
pub fn current_data(&self) -> Result<Tensor> {
|
||||
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||
pub fn current_data(&self) -> Result<Option<Tensor>> {
|
||||
let data = match self.all_data.as_ref() {
|
||||
None => None,
|
||||
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
|
||||
};
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.current_seq_len = 0;
|
||||
self.all_data = None;
|
||||
}
|
||||
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||
let seq_len = src.dim(self.dim)?;
|
||||
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
|
||||
// self.all_data.get_or_insert_with.
|
||||
if self.all_data.is_none() {
|
||||
let mut shape = src.dims().to_vec();
|
||||
shape[self.dim] = self.max_seq_len;
|
||||
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
|
||||
self.all_data = Some(ad)
|
||||
};
|
||||
let ad = self.all_data.as_mut().unwrap();
|
||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||
candle::bail!(
|
||||
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
||||
@ -56,8 +69,7 @@ impl Cache {
|
||||
self.max_seq_len
|
||||
)
|
||||
}
|
||||
self.all_data
|
||||
.slice_set(src, self.dim, self.current_seq_len)?;
|
||||
ad.slice_set(src, self.dim, self.current_seq_len)?;
|
||||
self.current_seq_len += seq_len;
|
||||
Ok(())
|
||||
}
|
||||
@ -70,32 +82,66 @@ pub struct KvCache {
|
||||
}
|
||||
|
||||
impl KvCache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let k = Cache::new(dim, &shape, dtype, dev)?;
|
||||
let v = Cache::new(dim, &shape, dtype, dev)?;
|
||||
Ok(Self { k, v })
|
||||
pub fn new(dim: usize, max_seq_len: usize) -> Self {
|
||||
let k = Cache::new(dim, max_seq_len);
|
||||
let v = Cache::new(dim, max_seq_len);
|
||||
Self { k, v }
|
||||
}
|
||||
|
||||
pub fn k(&self) -> Result<Tensor> {
|
||||
pub fn k_cache(&self) -> &Cache {
|
||||
&self.k
|
||||
}
|
||||
|
||||
pub fn v_cache(&self) -> &Cache {
|
||||
&self.v
|
||||
}
|
||||
|
||||
pub fn k_cache_mut(&mut self) -> &mut Cache {
|
||||
&mut self.k
|
||||
}
|
||||
|
||||
pub fn v_cache_mut(&mut self) -> &mut Cache {
|
||||
&mut self.v
|
||||
}
|
||||
|
||||
pub fn k(&self) -> Result<Option<Tensor>> {
|
||||
self.k.current_data()
|
||||
}
|
||||
|
||||
pub fn v(&self) -> Result<Tensor> {
|
||||
pub fn v(&self) -> Result<Option<Tensor>> {
|
||||
self.v.current_data()
|
||||
}
|
||||
|
||||
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
self.k.append(k)?;
|
||||
self.v.append(v)?;
|
||||
let k = self.k.current_data()?;
|
||||
let v = self.v.current_data()?;
|
||||
let out_k = self.k.current_data()?;
|
||||
let out_v = self.v.current_data()?;
|
||||
let k = match out_k {
|
||||
None => {
|
||||
let mut shape = k.dims().to_vec();
|
||||
shape[self.k.dim] = 0;
|
||||
Tensor::zeros(shape, k.dtype(), k.device())?
|
||||
}
|
||||
Some(k) => k,
|
||||
};
|
||||
let v = match out_v {
|
||||
None => {
|
||||
let mut shape = v.dims().to_vec();
|
||||
shape[self.k.dim] = 0;
|
||||
Tensor::zeros(shape, v.dtype(), v.device())?
|
||||
}
|
||||
Some(v) => v,
|
||||
};
|
||||
Ok((k, v))
|
||||
}
|
||||
|
||||
pub fn current_seq_len(&self) -> usize {
|
||||
self.k.current_seq_len()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.k.reset();
|
||||
self.v.reset();
|
||||
}
|
||||
}
|
||||
|
@ -11,8 +11,8 @@
|
||||
//! use candle_nn::{LayerNorm, Module};
|
||||
//! # fn main() -> candle::Result<()> {
|
||||
//!
|
||||
//! let w = Tensor::new(1f32, &Cpu)?;
|
||||
//! let b = Tensor::new(0f32, &Cpu)?;
|
||||
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
|
||||
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
|
||||
//! let layer = LayerNorm::new(w, b, 1e-5);
|
||||
//!
|
||||
//! let xs = Tensor::new(
|
||||
@ -107,6 +107,11 @@ impl LayerNorm {
|
||||
|
||||
impl Module for LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
if x.is_contiguous() && self.remove_mean {
|
||||
if let Some(bias) = self.bias.as_ref() {
|
||||
return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
|
||||
}
|
||||
}
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
|
||||
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||
@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.chunk(2, candle::D::Minus1)?;
|
||||
let xs = xs.chunk(2, D::Minus1)?;
|
||||
&xs[0].silu()? * &xs[1]
|
||||
}
|
||||
|
||||
@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = x.dim(candle::D::Minus1)?;
|
||||
let hidden_size = x.dim(D::Minus1)?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
|
||||
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
|
||||
}
|
||||
|
||||
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
let hidden_size_xs = xs.dim(candle::D::Minus1)?;
|
||||
let hidden_size_xs = xs.dim(D::Minus1)?;
|
||||
let hidden_size_alpha = alpha.dims1()?;
|
||||
if hidden_size_xs != hidden_size_alpha {
|
||||
candle::bail!(
|
||||
@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerNorm {
|
||||
eps: f32,
|
||||
}
|
||||
|
||||
impl candle::CustomOp3 for LayerNorm {
|
||||
fn name(&self) -> &'static str {
|
||||
"layer-norm"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
s1: &CpuStorage,
|
||||
l1: &Layout,
|
||||
s2: &CpuStorage,
|
||||
l2: &Layout,
|
||||
s3: &CpuStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
|
||||
let eps = self.eps;
|
||||
fn inner<
|
||||
T: candle::WithDType
|
||||
+ num_traits::Float
|
||||
+ num_traits::AsPrimitive<f32>
|
||||
+ num_traits::FromPrimitive,
|
||||
>(
|
||||
src: &[T],
|
||||
layout: &Layout,
|
||||
alpha: &[T],
|
||||
alpha_layout: &Layout,
|
||||
beta: &[T],
|
||||
beta_layout: &Layout,
|
||||
eps: f32,
|
||||
) -> Result<(CpuStorage, Shape)> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
};
|
||||
let alpha = match alpha_layout.contiguous_offsets() {
|
||||
None => candle::bail!("alpha has to be contiguous"),
|
||||
Some((o1, o2)) => &alpha[o1..o2],
|
||||
};
|
||||
let beta = match beta_layout.contiguous_offsets() {
|
||||
None => candle::bail!("beta has to be contiguous"),
|
||||
Some((o1, o2)) => &beta[o1..o2],
|
||||
};
|
||||
let el_count = layout.shape().elem_count();
|
||||
let dims = layout.shape().dims();
|
||||
let dim_m1 = dims[dims.len() - 1];
|
||||
let mut dst = vec![T::zero(); el_count];
|
||||
src.par_chunks(dim_m1)
|
||||
.zip(dst.par_chunks_mut(dim_m1))
|
||||
.for_each(|(src, dst)| {
|
||||
let mut sum = 0f32;
|
||||
let mut sum2 = 0f32;
|
||||
for v in src {
|
||||
let v = v.as_();
|
||||
sum += v;
|
||||
sum2 += v * v;
|
||||
}
|
||||
let mean = sum / dim_m1 as f32;
|
||||
let var = sum2 / dim_m1 as f32 - mean * mean;
|
||||
let inv_std = (var + eps).sqrt().recip();
|
||||
for ((d, s), (alpha, beta)) in
|
||||
dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
|
||||
{
|
||||
let alpha = alpha.as_();
|
||||
let beta = beta.as_();
|
||||
let d_ = (s.as_() - mean) * inv_std * alpha + beta;
|
||||
*d = T::from_f32(d_).unwrap_or_else(T::nan);
|
||||
}
|
||||
});
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, Shape::from_dims(dims)))
|
||||
}
|
||||
|
||||
use CpuStorage as C;
|
||||
match (s1, s2, s3) {
|
||||
(C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
|
||||
inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
|
||||
}
|
||||
(C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
|
||||
(C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
|
||||
_ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
s1: &candle::CudaStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::CudaStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::CudaStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
struct S {
|
||||
eps: f32,
|
||||
}
|
||||
impl Map3 for S {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
layout: &Layout,
|
||||
alpha: &CudaSlice<T>,
|
||||
alpha_layout: &Layout,
|
||||
beta: &CudaSlice<T>,
|
||||
beta_layout: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let alpha = match alpha_layout.contiguous_offsets() {
|
||||
None => candle::bail!("alpha has to be contiguous"),
|
||||
Some((o1, o2)) => alpha.slice(o1..o2),
|
||||
};
|
||||
let beta = match beta_layout.contiguous_offsets() {
|
||||
None => candle::bail!("beta has to be contiguous"),
|
||||
Some((o1, o2)) => beta.slice(o1..o2),
|
||||
};
|
||||
let el = layout.shape().elem_count();
|
||||
let dims = layout.shape().dims();
|
||||
let dim_m1 = dims[dims.len() - 1];
|
||||
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
|
||||
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (n_rows as u32, 1, 1),
|
||||
block_dim: (1024, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
let dev = s1.device();
|
||||
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
|
||||
let dst = candle::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, l1.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
s1: &candle::MetalStorage,
|
||||
l1: &Layout,
|
||||
s2: &candle::MetalStorage,
|
||||
l2: &Layout,
|
||||
s3: &candle::MetalStorage,
|
||||
l3: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
let device = s1.device();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let kernels = device.kernels();
|
||||
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
|
||||
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
|
||||
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
|
||||
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
|
||||
(dt1, dt2, dt3) => {
|
||||
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
|
||||
}
|
||||
};
|
||||
|
||||
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
|
||||
candle::bail!("Non contiguous layernorm is not implemented");
|
||||
}
|
||||
|
||||
let last_dim = l1.dims()[l1.shape().rank() - 1];
|
||||
let elem_count = l1.shape().elem_count();
|
||||
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
|
||||
candle_metal_kernels::call_layer_norm(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
self.eps,
|
||||
s1.buffer(),
|
||||
l1.start_offset() * s1.dtype().size_in_bytes(),
|
||||
s2.buffer(),
|
||||
l2.start_offset() * s2.dtype().size_in_bytes(),
|
||||
s3.buffer(),
|
||||
l3.start_offset() * s3.dtype().size_in_bytes(),
|
||||
&output,
|
||||
)
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
|
||||
Ok((newstorage, l1.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let hidden_size = x.dim(D::Minus1)?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let x = {
|
||||
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
x.broadcast_sub(&mean_x)?
|
||||
};
|
||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(alpha)?
|
||||
.broadcast_add(beta)
|
||||
}
|
||||
|
||||
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
|
||||
let hidden_size_xs = xs.dim(D::Minus1)?;
|
||||
let hidden_size_alpha = alpha.dims1()?;
|
||||
let hidden_size_beta = beta.dims1()?;
|
||||
if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
|
||||
candle::bail!(
|
||||
"shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
|
||||
xs.shape(),
|
||||
alpha.shape(),
|
||||
beta.shape()
|
||||
)
|
||||
}
|
||||
xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
|
||||
}
|
||||
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
|
||||
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
|
||||
let (b_size, c, h, w) = xs.dims4()?;
|
||||
@ -678,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
|
||||
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Identity;
|
||||
|
||||
impl Identity {
|
||||
pub fn new() -> Identity {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Identity {
|
||||
fn default() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Identity {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
Ok(xs.clone())
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> {
|
||||
let device = &Device::Cpu;
|
||||
let w = Tensor::new(&[3f32], device)?;
|
||||
let b = Tensor::new(&[0.5f32], device)?;
|
||||
let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8);
|
||||
let ln3 = LayerNorm::new(
|
||||
Tensor::cat(&[&w, &w, &w], 0)?,
|
||||
Tensor::cat(&[&b, &b, &b], 0)?,
|
||||
1e-8,
|
||||
);
|
||||
let ln = LayerNorm::new(w, b, 1e-8);
|
||||
|
||||
let two = Tensor::new(&[[[2f32]]], device)?;
|
||||
@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> {
|
||||
assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);
|
||||
|
||||
let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
|
||||
let res = ln.forward(&inp)?;
|
||||
let res = ln2.forward(&inp)?;
|
||||
assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);
|
||||
|
||||
let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
|
||||
let res = ln.forward(&inp)?;
|
||||
let res = ln3.forward(&inp)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res, 4)?,
|
||||
[[
|
||||
@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> {
|
||||
);
|
||||
let mean = (res.sum_keepdim(2)? / 3.0)?;
|
||||
// The average value should be `b`.
|
||||
assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&mean, 4)?,
|
||||
[[[0.5], [0.5], [0.5]]]
|
||||
);
|
||||
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
|
||||
// The standard deviation should be sqrt(`w`).
|
||||
assert_eq!(
|
||||
|
@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn layer_norm(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
|
||||
let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
|
||||
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(&t, 4)?,
|
||||
&[
|
||||
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
|
||||
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
|
||||
]
|
||||
);
|
||||
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
|
||||
assert_eq!(
|
||||
to_vec3_round(&t2, 4)?,
|
||||
&[
|
||||
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
|
||||
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
|
||||
]
|
||||
);
|
||||
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert!(diff < 1e-5);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax_numerical_stability() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
||||
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
|
||||
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.5.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.5.1" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.6.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::{collections::HashMap, usize};
|
||||
|
||||
@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
|
||||
DataType::Float16 => Some(DType::F16),
|
||||
DataType::Float => Some(DType::F32),
|
||||
DataType::Double => Some(DType::F64),
|
||||
DataType::Bool => Some(DType::U8),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -56,6 +57,15 @@ impl Attr for str {
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for GraphProto {
|
||||
const TYPE: AttributeType = AttributeType::Graph;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
attr.g
|
||||
.as_ref()
|
||||
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
@ -214,13 +224,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
// anymore.
|
||||
pub fn simple_eval(
|
||||
model: &onnx::ModelProto,
|
||||
inputs: HashMap<String, Value>,
|
||||
mut inputs: HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let graph = match &model.graph {
|
||||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
let mut values = inputs;
|
||||
simple_eval_(graph, &mut inputs)
|
||||
}
|
||||
|
||||
fn simple_eval_(
|
||||
graph: &onnx::GraphProto,
|
||||
values: &mut HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
@ -627,6 +643,13 @@ pub fn simple_eval(
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size
|
||||
"Size" => {
|
||||
let data = get(&node.input[0])?;
|
||||
let size: usize = data.dims().iter().product();
|
||||
let output = Tensor::from_slice(&[size as i64], (), data.device())?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
|
||||
"Sqrt" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
@ -877,6 +900,16 @@ pub fn simple_eval(
|
||||
let output = input.relu()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Ceil" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.ceil()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Floor" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.floor()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
|
||||
"Constant" => {
|
||||
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
|
||||
@ -948,6 +981,165 @@ pub fn simple_eval(
|
||||
let input = get(&node.input[0])?;
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
|
||||
"If" => {
|
||||
// protobuf encodes boolean false as 0 and true as 1
|
||||
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
|
||||
let attr_name = if cond != 0 {
|
||||
"then_branch"
|
||||
} else {
|
||||
"else_branch"
|
||||
};
|
||||
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
|
||||
if sub_graph.output.len() != node.output.len() {
|
||||
bail!(
|
||||
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
|
||||
node.name,
|
||||
sub_graph.output.len(),
|
||||
node.output.len()
|
||||
);
|
||||
}
|
||||
let branch_out = simple_eval_(sub_graph, values)?;
|
||||
for (i, out) in node.output.iter().enumerate() {
|
||||
values.insert(
|
||||
out.clone(),
|
||||
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
|
||||
"Pad" => {
|
||||
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
|
||||
let data = get(&node.input[0])?;
|
||||
let pads = get(&node.input[1])?;
|
||||
if node.input.len() > 2 {
|
||||
bail!(
|
||||
"unsupported number of inputs {} for Pad node {:?}, expected 2",
|
||||
node.input.len(),
|
||||
node.name
|
||||
);
|
||||
}
|
||||
if pads.rank() != 1 {
|
||||
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
|
||||
}
|
||||
if pads.dim(0).unwrap() != 2 * data.rank() {
|
||||
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
|
||||
}
|
||||
|
||||
let pads = pads.to_vec1::<i64>()?;
|
||||
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
|
||||
|
||||
match mode {
|
||||
"reflect" => {
|
||||
let mut out = data.clone();
|
||||
for (i, &dim) in data.dims().iter().enumerate().rev() {
|
||||
if pads_pre[i] == 0 && pads_post[i] == 0 {
|
||||
continue;
|
||||
}
|
||||
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
|
||||
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
|
||||
}
|
||||
let idx = if dim > 1 {
|
||||
let cycle_len = dim * 2 - 1;
|
||||
let skip = (pads_pre[i] as usize) % cycle_len;
|
||||
let idx = zigzag(0, (dim - 1) as i64)
|
||||
.skip(skip)
|
||||
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
|
||||
Tensor::from_iter(idx, out.device())?
|
||||
} else {
|
||||
Tensor::full(0i64, (dim,), out.device())?
|
||||
};
|
||||
|
||||
out = out.index_select(&idx, i)?;
|
||||
}
|
||||
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
_ => bail!(
|
||||
"unsupported 'mode' value {mode:?} for Pad node {:?}",
|
||||
node.name
|
||||
),
|
||||
}
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
|
||||
"Slice" => {
|
||||
let data = get(&node.input[0])?;
|
||||
let starts = get(&node.input[1])?;
|
||||
let ends = get(&node.input[2])?;
|
||||
let default_axes;
|
||||
let default_steps;
|
||||
let axes: &Tensor;
|
||||
let steps: &Tensor;
|
||||
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
|
||||
// they are set to [1, ..., 1] of length len(starts)
|
||||
match node.input.len() {
|
||||
3 => {
|
||||
let len = starts.dims()[0];
|
||||
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
|
||||
axes = default_axes.as_ref().unwrap();
|
||||
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||
steps = default_steps.as_ref().unwrap();
|
||||
}
|
||||
4 => {
|
||||
let len = starts.dims()[0];
|
||||
axes = get(&node.input[3])?;
|
||||
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||
steps = default_steps.as_ref().unwrap();
|
||||
}
|
||||
5 => {
|
||||
steps = get(&node.input[4])?;
|
||||
axes = get(&node.input[3])?;
|
||||
}
|
||||
_ => bail!(
|
||||
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
|
||||
node.input.len(),
|
||||
node
|
||||
),
|
||||
}
|
||||
|
||||
let mut out = data.clone();
|
||||
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
|
||||
// All negative elements of axes are made non-negative by
|
||||
// adding r to them, where r = rank(input).
|
||||
let axis = if axis < 0 {
|
||||
axis + data.rank() as i64
|
||||
} else {
|
||||
axis
|
||||
} as usize;
|
||||
|
||||
let data_dim = data.dims()[axis] as i64;
|
||||
let mut s = starts.get(i)?.to_scalar::<i64>()?;
|
||||
let mut e = ends.get(i)?.to_scalar::<i64>()?;
|
||||
// All negative values in starts[i] and ends[i] have
|
||||
// dims[axes[i]] added to them, where dims are the
|
||||
// dimensions of input.
|
||||
if s < 0 {
|
||||
s += data_dim;
|
||||
}
|
||||
if e < 0 {
|
||||
e += data_dim;
|
||||
}
|
||||
|
||||
let p = steps.get(i)?.to_scalar::<i64>()?;
|
||||
// starts[i] is clamped into the range [0, dims[axes[i]]]
|
||||
// for positive stepping and [0, dims[axes[i]]-1] for
|
||||
// negative stepping.
|
||||
// for positive stepping ends[axes[i]] is clamped to
|
||||
// [0, dims[axes[i]]], while for negative stepping it is
|
||||
// clamped to [-1, dims[axes[i]]-1].
|
||||
if p >= 0 {
|
||||
s = s.clamp(0, data_dim);
|
||||
e = e.clamp(0, data_dim);
|
||||
} else {
|
||||
s = s.clamp(0, data_dim - 1);
|
||||
e = e.clamp(-1, data_dim - 1);
|
||||
}
|
||||
|
||||
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||
out = out.index_select(&indexes, axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||
"ReduceMean" => {
|
||||
@ -1017,6 +1209,102 @@ pub fn simple_eval(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"ArgMin" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
|
||||
axis_i64,
|
||||
-rank_i64,
|
||||
rank_i64 - 1
|
||||
)
|
||||
}
|
||||
let axis = input.normalize_axis(axis_i64)?;
|
||||
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
if select_last_index == 1 {
|
||||
bail!("select_last_index for ArgMin is currently not supported")
|
||||
}
|
||||
let output = if keepdims == 1 {
|
||||
input.argmin_keepdim(axis)?
|
||||
} else {
|
||||
input.argmin(axis)?
|
||||
}
|
||||
.to_dtype(DType::I64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"ArgMax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
|
||||
axis_i64,
|
||||
-rank_i64,
|
||||
rank_i64 - 1
|
||||
)
|
||||
}
|
||||
let axis = input.normalize_axis(axis_i64)?;
|
||||
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
if select_last_index == 1 {
|
||||
bail!("select_last_index for ArgMin is currently not supported")
|
||||
}
|
||||
let output = if keepdims == 1 {
|
||||
input.argmax_keepdim(axis)?
|
||||
} else {
|
||||
input.argmax(axis)?
|
||||
}
|
||||
.to_dtype(DType::I64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"LeakyRelu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let dt = input.dtype();
|
||||
match dt {
|
||||
DType::U8 | DType::U32 | DType::I64 => {
|
||||
bail!(
|
||||
"unsupported dtype {}, only float types are allowed for LeakyRelu",
|
||||
dt.as_str()
|
||||
)
|
||||
}
|
||||
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
|
||||
}
|
||||
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
|
||||
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
|
||||
"Gemm" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
let c = get(&node.input[2])?;
|
||||
|
||||
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
|
||||
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
|
||||
|
||||
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
|
||||
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
|
||||
|
||||
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
|
||||
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
|
||||
|
||||
let a = if trans_a == 0 { a.clone() } else { a.t()? };
|
||||
let b = if trans_b == 0 { b.clone() } else { b.t()? };
|
||||
|
||||
let output = a
|
||||
.broadcast_mul(&alpha)?
|
||||
.broadcast_matmul(&b)?
|
||||
.broadcast_add(&c.broadcast_mul(&beta)?)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
404
candle-transformers/src/models/beit.rs
Normal file
404
candle-transformers/src/models/beit.rs
Normal file
@ -0,0 +1,404 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const IMG_SIZE: usize = 384;
|
||||
const PATCH_SIZE: usize = 16;
|
||||
const NUM_CLASSES: usize = 1000;
|
||||
const WINDOW_SIZE: usize = IMG_SIZE / PATCH_SIZE; // 384 / 16 = 24
|
||||
const NB_TOKENS: usize = WINDOW_SIZE * WINDOW_SIZE + 1; // 24 * 24 + 1 = 577
|
||||
|
||||
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
relative_position_bias_table: Tensor,
|
||||
relative_position_index: Tensor,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
) -> Result<Self> {
|
||||
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
|
||||
// num_relative_distance = token-token(47x47) + token-CLS(1) + CLS-token(1) + CLS-CLS(1) = 2212
|
||||
let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
|
||||
let relative_position_bias_table = vb.get(
|
||||
(num_relative_distance, num_heads),
|
||||
"relative_position_bias_table",
|
||||
)?;
|
||||
let relative_position_index =
|
||||
Self::gen_relative_position_index(relative_position_bias_table.device())?;
|
||||
let scale = 1. / ((dim / num_heads) as f64).sqrt();
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
relative_position_bias_table,
|
||||
relative_position_index,
|
||||
num_heads,
|
||||
scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
// See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/beit.py#L61
|
||||
fn gen_relative_position_index(device: &Device) -> Result<Tensor> {
|
||||
let num_relative_distance = (2 * WINDOW_SIZE - 1) * (2 * WINDOW_SIZE - 1) + 3;
|
||||
let w_area = WINDOW_SIZE * WINDOW_SIZE;
|
||||
|
||||
let t_arange: Tensor = Tensor::arange(0, WINDOW_SIZE as u32, device)?;
|
||||
let t_ndgrid = Tensor::meshgrid(&[&t_arange, &t_arange], false)?;
|
||||
let coords_flatten = Tensor::stack(&t_ndgrid, 0)?.flatten(1, 2)?;
|
||||
|
||||
let tmp1 = coords_flatten
|
||||
.unsqueeze(2)?
|
||||
.broadcast_as((2, w_area, w_area))?
|
||||
.to_dtype(DType::I64)?;
|
||||
let tmp2 = coords_flatten
|
||||
.unsqueeze(1)?
|
||||
.broadcast_as((2, w_area, w_area))?
|
||||
.to_dtype(DType::I64)?;
|
||||
let relative_coords = (tmp1 - tmp2)?
|
||||
.transpose(0, 1)? // 102
|
||||
.transpose(1, 2)? // 120
|
||||
.contiguous()?;
|
||||
|
||||
let relative_coords = relative_coords.slice_assign(
|
||||
&[0..w_area, 0..w_area, 0..1],
|
||||
&(relative_coords.i((0..w_area, 0..w_area, 0..1))? + (WINDOW_SIZE - 1) as f64)?,
|
||||
)?;
|
||||
let relative_coords = relative_coords.slice_assign(
|
||||
&[0..w_area, 0..w_area, 1..2],
|
||||
&(relative_coords.i((0..w_area, 0..w_area, 1..2))? + (WINDOW_SIZE - 1) as f64)?,
|
||||
)?;
|
||||
let relative_coords = relative_coords.slice_assign(
|
||||
&[0..w_area, 0..w_area, 0..1],
|
||||
&(relative_coords.i((.., .., 0..1))? * (2. * (WINDOW_SIZE as f64) - 1.))?,
|
||||
)?;
|
||||
|
||||
Tensor::zeros((w_area + 1, w_area + 1), DType::I64, device)?
|
||||
.slice_assign(&[1.., 1..], &relative_coords.sum(2)?)?
|
||||
.slice_assign(
|
||||
&[0..1, 0..(w_area + 1)],
|
||||
&(Tensor::ones((1, w_area + 1), DType::I64, device)?
|
||||
* ((num_relative_distance - 3) as f64))?
|
||||
.to_dtype(DType::I64)?,
|
||||
)?
|
||||
.slice_assign(
|
||||
&[0..(w_area + 1), 0..1],
|
||||
&(Tensor::ones((w_area + 1, 1), DType::I64, device)?
|
||||
* ((num_relative_distance - 2) as f64))?
|
||||
.to_dtype(DType::I64)?,
|
||||
)?
|
||||
.slice_assign(
|
||||
&[0..1, 0..1],
|
||||
&(Tensor::ones((1, 1), DType::I64, device)?
|
||||
* ((num_relative_distance - 1) as f64))?
|
||||
.to_dtype(DType::I64)?,
|
||||
)
|
||||
}
|
||||
|
||||
fn _get_rel_pos_bias(&self) -> Result<Tensor> {
|
||||
self.relative_position_bias_table
|
||||
.index_select(
|
||||
&self
|
||||
.relative_position_index
|
||||
.flatten_all()?
|
||||
.to_dtype(DType::U32)?,
|
||||
0,
|
||||
)?
|
||||
.reshape((NB_TOKENS, NB_TOKENS, ()))?
|
||||
.transpose(0, 1)? // 102
|
||||
.transpose(0, 2)? // 201
|
||||
.contiguous()?
|
||||
.unsqueeze(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = xs.dims3()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(xs)?
|
||||
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)? // 02134
|
||||
.transpose(0, 1)? // 20134
|
||||
.transpose(2, 3)?; // 20314
|
||||
let q = (qkv.i(0)? * self.scale)?;
|
||||
let k = qkv.i(1)?.contiguous()?;
|
||||
let v = qkv.i(2)?.contiguous()?;
|
||||
let attn = (&q.matmul(&k.t()?)? + self._get_rel_pos_bias())?;
|
||||
let attn = candle_nn::ops::softmax(&attn, D::Minus1)?;
|
||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
self.proj.forward(&attn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LayerScale {
|
||||
gamma: Tensor,
|
||||
}
|
||||
|
||||
impl LayerScale {
|
||||
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
|
||||
let gamma = vb.get(dim, "gamma")?;
|
||||
Ok(Self { gamma })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&self.gamma)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
|
||||
let out_features = in_features;
|
||||
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
|
||||
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.gelu()?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
ls1: LayerScale,
|
||||
norm2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
ls2: LayerScale,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
|
||||
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
|
||||
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
|
||||
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
|
||||
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
|
||||
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ls1,
|
||||
norm2,
|
||||
mlp,
|
||||
ls2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self
|
||||
.ls1
|
||||
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = self
|
||||
.ls2
|
||||
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
patch_size: (usize, usize),
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(vb: VarBuilder, patch_size: usize, in_chans: usize, embed_dim: usize) -> Result<Self> {
|
||||
let config = candle_nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
|
||||
Ok(Self {
|
||||
proj,
|
||||
patch_size: (patch_size, patch_size),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _c, h, w) = xs.dims4()?;
|
||||
let (patch_h, patch_w) = self.patch_size;
|
||||
if (h % patch_h) != 0 {
|
||||
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
|
||||
}
|
||||
if (w % patch_w) != 0 {
|
||||
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
|
||||
}
|
||||
let xs = self.proj.forward(xs)?;
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
// flatten embeddings.
|
||||
xs.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BeitVisionTransformer {
|
||||
patch_embed: PatchEmbed,
|
||||
cls_token: Tensor,
|
||||
blocks: Vec<Block>,
|
||||
norm: LayerNorm,
|
||||
head: Linear,
|
||||
}
|
||||
|
||||
impl BeitVisionTransformer {
|
||||
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let patch_embed = PatchEmbed::new(vb.pp("patch_embed"), PATCH_SIZE, 3, embed_dim)?;
|
||||
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
|
||||
let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
|
||||
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let blocks = (0..depth)
|
||||
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
cls_token,
|
||||
blocks,
|
||||
norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
Tensor::cat(&[&self.cls_token, &xs], 1)
|
||||
}
|
||||
|
||||
fn get_intermediate_layers_not_chunked(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
let mut output = Vec::new();
|
||||
for (i, blk) in self.blocks.iter().enumerate() {
|
||||
xs = blk.forward(&xs)?;
|
||||
if blocks_to_take.contains(&i) {
|
||||
output.push(xs.clone());
|
||||
}
|
||||
}
|
||||
if output.len() != blocks_to_take.len() {
|
||||
candle::bail!(
|
||||
"only {} / {} blocks found",
|
||||
output.len(),
|
||||
blocks_to_take.len()
|
||||
);
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn get_intermediate_layers(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
reshape: bool,
|
||||
return_class_token: bool,
|
||||
norm: bool,
|
||||
) -> Result<Tensor> {
|
||||
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
|
||||
let outputs = if norm {
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| self.norm.forward(out))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
let class_tokens = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 0)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 1..)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let outputs = if reshape {
|
||||
let (b, _c, w, h) = xs.dims4()?;
|
||||
let patch_size = self.patch_embed.patch_size.0;
|
||||
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| {
|
||||
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
let outputs = if return_class_token {
|
||||
outputs
|
||||
.iter()
|
||||
.zip(class_tokens.iter())
|
||||
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
Tensor::stack(&outputs[..], 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for BeitVisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
for blk in self.blocks.iter() {
|
||||
xs = blk.forward(&xs)?
|
||||
}
|
||||
let xs_moy_local_tokens = xs.i((.., 1..))?.mean(1)?;
|
||||
let xs_norm = self.norm.forward(&xs_moy_local_tokens)?;
|
||||
self.head.forward(&xs_norm)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_base(vb: VarBuilder) -> Result<BeitVisionTransformer> {
|
||||
BeitVisionTransformer::new(vb, 12, 768, 12)
|
||||
}
|
||||
|
||||
pub fn vit_large(vb: VarBuilder) -> Result<BeitVisionTransformer> {
|
||||
BeitVisionTransformer::new(vb, 24, 1024, 16)
|
||||
}
|
@ -262,6 +262,20 @@ impl ClipEncoder {
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
// required by LLaVA
|
||||
pub fn output_hidden_states(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
causal_attention_mask: Option<&Tensor>,
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let mut xs = xs.clone();
|
||||
let mut hidden_states = Vec::new();
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, causal_attention_mask)?;
|
||||
hidden_states.push(xs.clone());
|
||||
}
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
/// A CLIP transformer based model.
|
||||
|
@ -46,6 +46,19 @@ impl ClipVisionConfig {
|
||||
patch_size: 32,
|
||||
}
|
||||
}
|
||||
pub fn clip_vit_large_patch14_336() -> Self {
|
||||
Self {
|
||||
embed_dim: 1024,
|
||||
activation: Activation::QuickGelu,
|
||||
intermediate_size: 4096,
|
||||
num_hidden_layers: 24,
|
||||
num_attention_heads: 16,
|
||||
projection_dim: 768,
|
||||
num_channels: 3,
|
||||
image_size: 336,
|
||||
patch_size: 14,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
|
||||
@ -130,6 +143,17 @@ impl ClipVisionTransformer {
|
||||
pre_layer_norm,
|
||||
})
|
||||
}
|
||||
// required by LLaVA
|
||||
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
|
||||
let hidden_states = pixel_values
|
||||
.apply(&self.embeddings)?
|
||||
.apply(&self.pre_layer_norm)?;
|
||||
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
|
||||
let encoder_outputs = result.last().unwrap();
|
||||
let pooled_output = encoder_outputs.i((.., 0, ..))?;
|
||||
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ClipVisionTransformer {
|
||||
|
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
@ -0,0 +1,553 @@
|
||||
use candle::D::Minus1;
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn::ops::Identity;
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
|
||||
BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
|
||||
};
|
||||
|
||||
use crate::models::dinov2::DinoVisionTransformer;
|
||||
|
||||
pub struct DepthAnythingV2Config {
|
||||
out_channel_sizes: [usize; 4],
|
||||
in_channel_size: usize, // embed_dim in the Dino model
|
||||
num_features: usize,
|
||||
use_batch_norm: bool,
|
||||
use_class_token: bool,
|
||||
layer_ids_vits: Vec<usize>,
|
||||
input_image_size: usize,
|
||||
target_patch_size: usize,
|
||||
}
|
||||
|
||||
impl DepthAnythingV2Config {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
out_channel_sizes: [usize; 4],
|
||||
in_channel_size: usize,
|
||||
num_features: usize,
|
||||
use_batch_norm: bool,
|
||||
use_class_token: bool,
|
||||
layer_ids_vits: Vec<usize>,
|
||||
input_image_size: usize,
|
||||
target_patch_size: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
out_channel_sizes,
|
||||
in_channel_size,
|
||||
num_features,
|
||||
use_batch_norm,
|
||||
use_class_token,
|
||||
layer_ids_vits,
|
||||
input_image_size,
|
||||
target_patch_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_small() -> Self {
|
||||
Self {
|
||||
out_channel_sizes: [48, 96, 192, 384],
|
||||
in_channel_size: 384,
|
||||
num_features: 64,
|
||||
use_batch_norm: false,
|
||||
use_class_token: false,
|
||||
layer_ids_vits: vec![2, 5, 8, 11],
|
||||
input_image_size: 518,
|
||||
target_patch_size: 518 / 14,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_base() -> Self {
|
||||
Self {
|
||||
out_channel_sizes: [96, 192, 384, 768],
|
||||
in_channel_size: 768,
|
||||
num_features: 128,
|
||||
use_batch_norm: false,
|
||||
use_class_token: false,
|
||||
layer_ids_vits: vec![2, 5, 8, 11],
|
||||
input_image_size: 518,
|
||||
target_patch_size: 518 / 14,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_large() -> Self {
|
||||
Self {
|
||||
out_channel_sizes: [256, 512, 1024, 1024],
|
||||
in_channel_size: 1024,
|
||||
num_features: 256,
|
||||
use_batch_norm: false,
|
||||
use_class_token: false,
|
||||
layer_ids_vits: vec![4, 11, 17, 23],
|
||||
input_image_size: 518,
|
||||
target_patch_size: 518 / 14,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_giant() -> Self {
|
||||
Self {
|
||||
out_channel_sizes: [1536, 1536, 1536, 1536],
|
||||
in_channel_size: 1536,
|
||||
num_features: 384,
|
||||
use_batch_norm: false,
|
||||
use_class_token: false,
|
||||
layer_ids_vits: vec![9, 19, 29, 39],
|
||||
input_image_size: 518,
|
||||
target_patch_size: 518 / 14,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResidualConvUnit {
|
||||
activation: Activation,
|
||||
conv1: Conv2d,
|
||||
conv2: Conv2d,
|
||||
batch_norm1: Option<BatchNorm>,
|
||||
batch_norm2: Option<BatchNorm>,
|
||||
}
|
||||
|
||||
impl ResidualConvUnit {
|
||||
pub fn new(
|
||||
conf: &DepthAnythingV2Config,
|
||||
activation: Activation,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
const KERNEL_SIZE: usize = 3;
|
||||
let conv_cfg = Conv2dConfig {
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
};
|
||||
let conv1 = conv2d(
|
||||
conf.num_features,
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("conv1"),
|
||||
)?;
|
||||
let conv2 = conv2d(
|
||||
conf.num_features,
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("conv2"),
|
||||
)?;
|
||||
|
||||
let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
|
||||
true => {
|
||||
let batch_norm_cfg = BatchNormConfig {
|
||||
eps: 1e-05,
|
||||
remove_mean: false,
|
||||
affine: true,
|
||||
momentum: 0.1,
|
||||
};
|
||||
(
|
||||
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
|
||||
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
|
||||
)
|
||||
}
|
||||
false => (None, None),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
activation,
|
||||
conv1,
|
||||
conv2,
|
||||
batch_norm1,
|
||||
batch_norm2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ResidualConvUnit {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let out = self.activation.forward(xs)?;
|
||||
let out = self.conv1.forward(&out)?;
|
||||
let out = if let Some(batch_norm1) = &self.batch_norm1 {
|
||||
batch_norm1.forward_train(&out)?
|
||||
} else {
|
||||
out
|
||||
};
|
||||
|
||||
let out = self.activation.forward(&out)?;
|
||||
let out = self.conv2.forward(&out)?;
|
||||
let out = if let Some(batch_norm2) = &self.batch_norm2 {
|
||||
batch_norm2.forward_train(&out)?
|
||||
} else {
|
||||
out
|
||||
};
|
||||
|
||||
out + xs
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FeatureFusionBlock {
|
||||
res_conv_unit1: ResidualConvUnit,
|
||||
res_conv_unit2: ResidualConvUnit,
|
||||
output_conv: Conv2d,
|
||||
target_patch_size: usize,
|
||||
}
|
||||
|
||||
impl FeatureFusionBlock {
|
||||
pub fn new(
|
||||
conf: &DepthAnythingV2Config,
|
||||
target_patch_size: usize,
|
||||
activation: Activation,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
const KERNEL_SIZE: usize = 1;
|
||||
let conv_cfg = Conv2dConfig {
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
};
|
||||
let output_conv = conv2d(
|
||||
conf.num_features,
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("out_conv"),
|
||||
)?;
|
||||
let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
|
||||
let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
|
||||
|
||||
Ok(Self {
|
||||
res_conv_unit1,
|
||||
res_conv_unit2,
|
||||
output_conv,
|
||||
target_patch_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for FeatureFusionBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let out = self.res_conv_unit2.forward(xs)?;
|
||||
let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
|
||||
|
||||
self.output_conv.forward(&out)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Scratch {
|
||||
layer1_rn: Conv2d,
|
||||
layer2_rn: Conv2d,
|
||||
layer3_rn: Conv2d,
|
||||
layer4_rn: Conv2d,
|
||||
refine_net1: FeatureFusionBlock,
|
||||
refine_net2: FeatureFusionBlock,
|
||||
refine_net3: FeatureFusionBlock,
|
||||
refine_net4: FeatureFusionBlock,
|
||||
output_conv1: Conv2d,
|
||||
output_conv2: Sequential,
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||
const KERNEL_SIZE: usize = 3;
|
||||
let conv_cfg = Conv2dConfig {
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
};
|
||||
|
||||
let layer1_rn = conv2d_no_bias(
|
||||
conf.out_channel_sizes[0],
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("layer1_rn"),
|
||||
)?;
|
||||
let layer2_rn = conv2d_no_bias(
|
||||
conf.out_channel_sizes[1],
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("layer2_rn"),
|
||||
)?;
|
||||
let layer3_rn = conv2d_no_bias(
|
||||
conf.out_channel_sizes[2],
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("layer3_rn"),
|
||||
)?;
|
||||
let layer4_rn = conv2d_no_bias(
|
||||
conf.out_channel_sizes[3],
|
||||
conf.num_features,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("layer4_rn"),
|
||||
)?;
|
||||
|
||||
let refine_net1 = FeatureFusionBlock::new(
|
||||
conf,
|
||||
conf.target_patch_size * 8,
|
||||
Activation::Relu,
|
||||
vb.pp("refinenet1"),
|
||||
)?;
|
||||
let refine_net2 = FeatureFusionBlock::new(
|
||||
conf,
|
||||
conf.target_patch_size * 4,
|
||||
Activation::Relu,
|
||||
vb.pp("refinenet2"),
|
||||
)?;
|
||||
let refine_net3 = FeatureFusionBlock::new(
|
||||
conf,
|
||||
conf.target_patch_size * 2,
|
||||
Activation::Relu,
|
||||
vb.pp("refinenet3"),
|
||||
)?;
|
||||
let refine_net4 = FeatureFusionBlock::new(
|
||||
conf,
|
||||
conf.target_patch_size,
|
||||
Activation::Relu,
|
||||
vb.pp("refinenet4"),
|
||||
)?;
|
||||
|
||||
let conv_cfg = Conv2dConfig {
|
||||
padding: 1,
|
||||
stride: 1,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
};
|
||||
let output_conv1 = conv2d(
|
||||
conf.num_features,
|
||||
conf.num_features / 2,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("output_conv1"),
|
||||
)?;
|
||||
|
||||
let output_conv2 = seq();
|
||||
const HEAD_FEATURES_2: usize = 32;
|
||||
const OUT_CHANNELS_2: usize = 1;
|
||||
const KERNEL_SIZE_2: usize = 1;
|
||||
let output_conv2 = output_conv2.add(conv2d(
|
||||
conf.num_features / 2,
|
||||
HEAD_FEATURES_2,
|
||||
KERNEL_SIZE,
|
||||
conv_cfg,
|
||||
vb.pp("output_conv2").pp("0"),
|
||||
)?);
|
||||
let output_conv2 = output_conv2
|
||||
.add(Activation::Relu)
|
||||
.add(conv2d(
|
||||
HEAD_FEATURES_2,
|
||||
OUT_CHANNELS_2,
|
||||
KERNEL_SIZE_2,
|
||||
conv_cfg,
|
||||
vb.pp("output_conv2").pp("2"),
|
||||
)?)
|
||||
.add(Activation::Relu);
|
||||
|
||||
Ok(Self {
|
||||
layer1_rn,
|
||||
layer2_rn,
|
||||
layer3_rn,
|
||||
layer4_rn,
|
||||
refine_net1,
|
||||
refine_net2,
|
||||
refine_net3,
|
||||
refine_net4,
|
||||
output_conv1,
|
||||
output_conv2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const NUM_CHANNELS: usize = 4;
|
||||
|
||||
pub struct DPTHead<'a> {
|
||||
conf: &'a DepthAnythingV2Config,
|
||||
projections: Vec<Conv2d>,
|
||||
resize_layers: Vec<Box<dyn Module>>,
|
||||
readout_projections: Vec<Sequential>,
|
||||
scratch: Scratch,
|
||||
}
|
||||
|
||||
impl<'a> DPTHead<'a> {
|
||||
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
|
||||
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
|
||||
projections.push(conv2d(
|
||||
conf.in_channel_size,
|
||||
*out_channel_size,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("projects").pp(conv_index.to_string()),
|
||||
)?);
|
||||
}
|
||||
|
||||
let resize_layers: Vec<Box<dyn Module>> = vec![
|
||||
Box::new(conv_transpose2d(
|
||||
conf.out_channel_sizes[0],
|
||||
conf.out_channel_sizes[0],
|
||||
4,
|
||||
ConvTranspose2dConfig {
|
||||
padding: 0,
|
||||
stride: 4,
|
||||
dilation: 1,
|
||||
output_padding: 0,
|
||||
},
|
||||
vb.pp("resize_layers").pp("0"),
|
||||
)?),
|
||||
Box::new(conv_transpose2d(
|
||||
conf.out_channel_sizes[1],
|
||||
conf.out_channel_sizes[1],
|
||||
2,
|
||||
ConvTranspose2dConfig {
|
||||
padding: 0,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
output_padding: 0,
|
||||
},
|
||||
vb.pp("resize_layers").pp("1"),
|
||||
)?),
|
||||
Box::new(Identity::new()),
|
||||
Box::new(conv2d(
|
||||
conf.out_channel_sizes[3],
|
||||
conf.out_channel_sizes[3],
|
||||
3,
|
||||
Conv2dConfig {
|
||||
padding: 1,
|
||||
stride: 2,
|
||||
dilation: 1,
|
||||
groups: 1,
|
||||
},
|
||||
vb.pp("resize_layers").pp("3"),
|
||||
)?),
|
||||
];
|
||||
|
||||
let readout_projections = if conf.use_class_token {
|
||||
let rop = Vec::with_capacity(NUM_CHANNELS);
|
||||
for rop_index in 0..NUM_CHANNELS {
|
||||
seq()
|
||||
.add(linear(
|
||||
2 * conf.in_channel_size,
|
||||
conf.in_channel_size,
|
||||
vb.pp("readout_projects").pp(rop_index.to_string()),
|
||||
)?)
|
||||
.add(Activation::Gelu);
|
||||
}
|
||||
rop
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
|
||||
|
||||
Ok(Self {
|
||||
conf,
|
||||
projections,
|
||||
resize_layers,
|
||||
readout_projections,
|
||||
scratch,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DPTHead<'_> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
|
||||
for i in 0..NUM_CHANNELS {
|
||||
let x = if self.conf.use_class_token {
|
||||
let x = xs.get(i)?.get(0)?;
|
||||
let class_token = xs.get(i)?.get(1)?;
|
||||
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
|
||||
let to_cat = [x, readout];
|
||||
let cat = Tensor::cat(&to_cat, Minus1)?;
|
||||
self.readout_projections[i].forward(&cat)?
|
||||
} else {
|
||||
xs.get(i)?
|
||||
};
|
||||
let x_dims = x.dims();
|
||||
|
||||
let x = x.permute((0, 2, 1))?.reshape((
|
||||
x_dims[0],
|
||||
x_dims[x_dims.len() - 1],
|
||||
self.conf.target_patch_size,
|
||||
self.conf.target_patch_size,
|
||||
))?;
|
||||
let x = self.projections[i].forward(&x)?;
|
||||
|
||||
let x = self.resize_layers[i].forward(&x)?;
|
||||
out.push(x);
|
||||
}
|
||||
|
||||
let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
|
||||
let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
|
||||
let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
|
||||
let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
|
||||
|
||||
let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
|
||||
|
||||
let res3_out = self
|
||||
.scratch
|
||||
.refine_net3
|
||||
.res_conv_unit1
|
||||
.forward(&layer_3_rn)?;
|
||||
let res3_out = path4.add(&res3_out)?;
|
||||
let path3 = self.scratch.refine_net3.forward(&res3_out)?;
|
||||
|
||||
let res2_out = self
|
||||
.scratch
|
||||
.refine_net2
|
||||
.res_conv_unit1
|
||||
.forward(&layer_2_rn)?;
|
||||
let res2_out = path3.add(&res2_out)?;
|
||||
let path2 = self.scratch.refine_net2.forward(&res2_out)?;
|
||||
|
||||
let res1_out = self
|
||||
.scratch
|
||||
.refine_net1
|
||||
.res_conv_unit1
|
||||
.forward(&layer_1_rn)?;
|
||||
let res1_out = path2.add(&res1_out)?;
|
||||
let path1 = self.scratch.refine_net1.forward(&res1_out)?;
|
||||
|
||||
let out = self.scratch.output_conv1.forward(&path1)?;
|
||||
|
||||
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
|
||||
|
||||
self.scratch.output_conv2.forward(&out)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DepthAnythingV2<'a> {
|
||||
pretrained: &'a DinoVisionTransformer,
|
||||
depth_head: DPTHead<'a>,
|
||||
conf: &'a DepthAnythingV2Config,
|
||||
}
|
||||
|
||||
impl<'a> DepthAnythingV2<'a> {
|
||||
pub fn new(
|
||||
pretrained: &'a DinoVisionTransformer,
|
||||
conf: &'a DepthAnythingV2Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
|
||||
|
||||
Ok(Self {
|
||||
pretrained,
|
||||
depth_head,
|
||||
conf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Module for DepthAnythingV2<'a> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let features = self.pretrained.get_intermediate_layers(
|
||||
xs,
|
||||
&self.conf.layer_ids_vits,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
)?;
|
||||
let depth = self.depth_head.forward(&features)?;
|
||||
|
||||
depth.relu()
|
||||
}
|
||||
}
|
@ -258,6 +258,84 @@ impl DinoVisionTransformer {
|
||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
||||
}
|
||||
|
||||
fn get_intermediate_layers_not_chunked(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
let mut output = Vec::new();
|
||||
for (i, blk) in self.blocks.iter().enumerate() {
|
||||
xs = blk.forward(&xs)?;
|
||||
if blocks_to_take.contains(&i) {
|
||||
output.push(xs.clone());
|
||||
}
|
||||
}
|
||||
if output.len() != blocks_to_take.len() {
|
||||
candle::bail!(
|
||||
"only {} / {} blocks found",
|
||||
output.len(),
|
||||
blocks_to_take.len()
|
||||
);
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn get_intermediate_layers(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
reshape: bool,
|
||||
return_class_token: bool,
|
||||
norm: bool,
|
||||
) -> Result<Tensor> {
|
||||
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
|
||||
let outputs = if norm {
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| self.norm.forward(out))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
let class_tokens = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 0)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 1..)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let outputs = if reshape {
|
||||
let (b, _c, w, h) = xs.dims4()?;
|
||||
let patch_size = self.patch_embed.patch_size.0;
|
||||
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| {
|
||||
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
let outputs = if return_class_token {
|
||||
outputs
|
||||
.iter()
|
||||
.zip(class_tokens.iter())
|
||||
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
Tensor::stack(&outputs[..], 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DinoVisionTransformer {
|
||||
|
281
candle-transformers/src/models/dinov2reg4.rs
Normal file
281
candle-transformers/src/models/dinov2reg4.rs
Normal file
@ -0,0 +1,281 @@
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const IMG_SIZE: usize = 518;
|
||||
const PATCH_SIZE: usize = 14;
|
||||
const NUM_CLASSES: usize = 7806; // PlantCLEF2024 DINOv2 (https://zenodo.org/records/10848263)
|
||||
|
||||
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
qkv: Linear,
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
) -> Result<Self> {
|
||||
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
|
||||
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
|
||||
let scale = 1. / ((dim / num_heads) as f64).sqrt();
|
||||
Ok(Self {
|
||||
qkv,
|
||||
proj,
|
||||
num_heads,
|
||||
scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = xs.dims3()?;
|
||||
let qkv = self
|
||||
.qkv
|
||||
.forward(xs)?
|
||||
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)? // 02134
|
||||
.transpose(0, 1)? // 20134
|
||||
.transpose(2, 3)?; // 20314
|
||||
let q = (qkv.i(0)? * self.scale)?;
|
||||
let k = qkv.i(1)?.contiguous()?;
|
||||
let v = qkv.i(2)?.contiguous()?;
|
||||
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
|
||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
self.proj.forward(&attn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LayerScale {
|
||||
gamma: Tensor,
|
||||
}
|
||||
|
||||
impl LayerScale {
|
||||
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
|
||||
let gamma = vb.get(dim, "gamma")?;
|
||||
Ok(Self { gamma })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerScale {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.broadcast_mul(&self.gamma)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
fc1: Linear,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
|
||||
let out_features = in_features;
|
||||
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
|
||||
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
|
||||
Ok(Self { fc1, fc2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.gelu()?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
ls1: LayerScale,
|
||||
norm2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
ls2: LayerScale,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
|
||||
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
|
||||
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
|
||||
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
|
||||
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
|
||||
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
ls1,
|
||||
norm2,
|
||||
mlp,
|
||||
ls2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self
|
||||
.ls1
|
||||
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = self
|
||||
.ls2
|
||||
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
patch_size: (usize, usize),
|
||||
num_patches: usize,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
) -> Result<Self> {
|
||||
let config = candle_nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
|
||||
let num_patches = (img_size / patch_size) * (img_size / patch_size);
|
||||
Ok(Self {
|
||||
proj,
|
||||
patch_size: (patch_size, patch_size),
|
||||
num_patches,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _c, h, w) = xs.dims4()?;
|
||||
let (patch_h, patch_w) = self.patch_size;
|
||||
if (h % patch_h) != 0 {
|
||||
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
|
||||
}
|
||||
if (w % patch_w) != 0 {
|
||||
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
|
||||
}
|
||||
let xs = self.proj.forward(xs)?;
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
// flatten embeddings.
|
||||
xs.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DinoVisionTransformer {
|
||||
patch_embed: PatchEmbed,
|
||||
cls_token: Tensor,
|
||||
reg_token: Tensor,
|
||||
pos_embed: Tensor,
|
||||
blocks: Vec<Block>,
|
||||
norm: LayerNorm,
|
||||
head: Linear,
|
||||
}
|
||||
|
||||
impl DinoVisionTransformer {
|
||||
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let patch_embed =
|
||||
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
|
||||
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
|
||||
let reg_token = vb.get((1, 4, embed_dim), "reg_token")?;
|
||||
let pos_embed = vb.get((1, patch_embed.num_patches, embed_dim), "pos_embed")?;
|
||||
let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
|
||||
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let blocks = (0..depth)
|
||||
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
cls_token,
|
||||
reg_token,
|
||||
pos_embed,
|
||||
blocks,
|
||||
norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
|
||||
let npatch = xs.dim(1)? - 1;
|
||||
let n = self.pos_embed.dim(1)? - 1;
|
||||
let sqrt_n = (n as f64).sqrt();
|
||||
if npatch == n && w == h {
|
||||
return Ok(self.pos_embed.clone());
|
||||
}
|
||||
let patch_pos_embed = &self.pos_embed;
|
||||
let dim = xs.dim(D::Minus1)?;
|
||||
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
|
||||
let patch_pos_embed = patch_pos_embed
|
||||
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)?;
|
||||
// This uses bicubic interpolation in the original implementation.
|
||||
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
|
||||
let el_count = patch_pos_embed.shape().elem_count();
|
||||
patch_pos_embed
|
||||
.transpose(1, 2)?
|
||||
.transpose(2, 3)?
|
||||
.reshape((1, el_count / dim, dim))
|
||||
}
|
||||
|
||||
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _nc, w, h) = xs.dims4()?;
|
||||
if (w != IMG_SIZE) || (h != IMG_SIZE) {
|
||||
panic!("Error: The input tensor should have the shape: Bx3x518x518.");
|
||||
}
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h)?)?;
|
||||
let xs = Tensor::cat(&[&self.cls_token, &self.reg_token, &xs], 1)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DinoVisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
for blk in self.blocks.iter() {
|
||||
xs = blk.forward(&xs)?
|
||||
}
|
||||
let xs = self.norm.forward(&xs)?;
|
||||
let xs_norm_clstoken = xs.i((.., 0))?;
|
||||
self.head.forward(&xs_norm_clstoken)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
|
||||
DinoVisionTransformer::new(vb, 12, 384, 6)
|
||||
}
|
||||
|
||||
pub fn vit_base(vb: VarBuilder) -> Result<DinoVisionTransformer> {
|
||||
DinoVisionTransformer::new(vb, 12, 768, 12)
|
||||
}
|
418
candle-transformers/src/models/eva2.rs
Normal file
418
candle-transformers/src/models/eva2.rs
Normal file
@ -0,0 +1,418 @@
|
||||
use candle::{IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const IMG_SIZE: usize = 448;
|
||||
const PATCH_SIZE: usize = 14;
|
||||
const NUM_CLASSES: usize = 1000;
|
||||
|
||||
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
|
||||
if bias {
|
||||
candle_nn::linear(in_dim, out_dim, vb)
|
||||
} else {
|
||||
candle_nn::linear_no_bias(in_dim, out_dim, vb)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Attention {
|
||||
q: Linear,
|
||||
k: Linear,
|
||||
v: Linear,
|
||||
proj: Linear,
|
||||
rot_pos_embed: Tensor,
|
||||
num_heads: usize,
|
||||
scale: f64,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
dim: usize,
|
||||
num_heads: usize,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
rot_pos_embed: &Tensor,
|
||||
) -> Result<Self> {
|
||||
let q = linear(vb.pp("q_proj"), dim, dim, qkv_bias)?;
|
||||
let k = linear(vb.pp("k_proj"), dim, dim, false)?; // no bias for Key
|
||||
let v = linear(vb.pp("v_proj"), dim, dim, qkv_bias)?;
|
||||
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
|
||||
let rot_pos_embed = rot_pos_embed.clone();
|
||||
let scale = 1. / ((dim / num_heads) as f64).sqrt();
|
||||
Ok(Self {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj,
|
||||
rot_pos_embed,
|
||||
num_heads,
|
||||
scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
// See: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed_sincos.py#L210
|
||||
fn apply_rot_embed_cat(x: &Tensor, emb: &Tensor) -> Result<Tensor> {
|
||||
let cos_emb = emb.i((0.., 64..128))?; //.transpose(0, 1)?;
|
||||
let sin_emb = emb.i((0.., 0..64))?; //.transpose(0, 1)?;
|
||||
let index_even: [u32; 32] = (0u32..=63)
|
||||
.step_by(2)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("wrong size iterator");
|
||||
let index_odd: [u32; 32] = (1u32..=63)
|
||||
.step_by(2)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.expect("wrong size iterator");
|
||||
let t_index_even = Tensor::new(&index_even, x.device())?;
|
||||
let t_index_odd = Tensor::new(&index_odd, x.device())?;
|
||||
let x_c = x.contiguous()?;
|
||||
let rot_x_even = x_c.index_select(&t_index_even, D::Minus1)?;
|
||||
let rot_x_odd_minus = (-1.0 * x_c.index_select(&t_index_odd, D::Minus1)?)?;
|
||||
let rot_x =
|
||||
Tensor::stack(&[&rot_x_odd_minus, &rot_x_even], D::Minus1)?.reshape(x.shape())?;
|
||||
x.broadcast_mul(&cos_emb)? + rot_x.broadcast_mul(&sin_emb)?
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Attention {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = xs.dims3()?;
|
||||
let qkv = Tensor::cat(
|
||||
&[
|
||||
&self.q.forward(xs)?,
|
||||
&self.k.forward(xs)?,
|
||||
&self.v.forward(xs)?,
|
||||
],
|
||||
2,
|
||||
)?
|
||||
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)? // 02134
|
||||
.transpose(0, 1)? // 20134
|
||||
.transpose(2, 3)?; // 20314
|
||||
let q = qkv.i(0)?;
|
||||
let k = qkv.i(1)?.contiguous()?;
|
||||
let v = qkv.i(2)?.contiguous()?;
|
||||
|
||||
let npt = 1; // num_prefix_tokens = 1 for CLS token
|
||||
let q = Tensor::cat(
|
||||
&[
|
||||
&q.i((0.., 0.., ..npt, 0..))?,
|
||||
&Self::apply_rot_embed_cat(&q.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
|
||||
],
|
||||
2,
|
||||
)?;
|
||||
let k = Tensor::cat(
|
||||
&[
|
||||
&k.i((0.., 0.., ..npt, 0..))?,
|
||||
&Self::apply_rot_embed_cat(&k.i((0.., 0.., npt.., 0..))?, &self.rot_pos_embed)?,
|
||||
],
|
||||
2,
|
||||
)?;
|
||||
|
||||
let q = (q * self.scale)?;
|
||||
let attn = &q.matmul(&k.t()?)?;
|
||||
let attn = candle_nn::ops::softmax(attn, D::Minus1)?;
|
||||
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
|
||||
self.proj.forward(&attn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Mlp {
|
||||
fc1_g: Linear,
|
||||
fc1_x: Linear,
|
||||
norm: LayerNorm,
|
||||
fc2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
|
||||
let out_features = in_features;
|
||||
let fc1_g = linear(vb.pp("fc1_g"), in_features, hidden_features, bias)?;
|
||||
let fc1_x = linear(vb.pp("fc1_x"), in_features, hidden_features, bias)?;
|
||||
let norm = layer_norm(hidden_features, 1e-6, vb.pp("norm"))?;
|
||||
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
|
||||
Ok(Self {
|
||||
fc1_g,
|
||||
fc1_x,
|
||||
norm,
|
||||
fc2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs_g = self.fc1_g.forward(xs)?.silu()?;
|
||||
let xs = self.fc1_x.forward(xs)?;
|
||||
let xs = self.norm.forward(&(xs_g.mul(&xs)?))?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Block {
|
||||
norm1: LayerNorm,
|
||||
attn: Attention,
|
||||
norm2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn new(vb: VarBuilder, dim: usize, num_heads: usize, rot_pos_embed: &Tensor) -> Result<Self> {
|
||||
let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
|
||||
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true, rot_pos_embed)?;
|
||||
let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
|
||||
let hidden_dim = dim * 4 * 2 / 3; // 768 * 4 * 2 / 3 = 3072 * 2 / 3 = 2048
|
||||
let mlp = Mlp::new(vb.pp("mlp"), dim, hidden_dim, true)?;
|
||||
Ok(Self {
|
||||
norm1,
|
||||
attn,
|
||||
norm2,
|
||||
mlp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Block {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = &self.attn.forward(&self.norm1.forward(xs)?)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = &self.mlp.forward(&self.norm2.forward(&xs)?)?;
|
||||
xs + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PatchEmbed {
|
||||
proj: candle_nn::Conv2d,
|
||||
patch_size: (usize, usize),
|
||||
num_patches: usize,
|
||||
}
|
||||
|
||||
impl PatchEmbed {
|
||||
fn new(
|
||||
vb: VarBuilder,
|
||||
img_size: usize,
|
||||
patch_size: usize,
|
||||
in_chans: usize,
|
||||
embed_dim: usize,
|
||||
) -> Result<Self> {
|
||||
let config = candle_nn::Conv2dConfig {
|
||||
stride: patch_size,
|
||||
..Default::default()
|
||||
};
|
||||
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
|
||||
let num_patches = (img_size / patch_size) * (img_size / patch_size);
|
||||
Ok(Self {
|
||||
proj,
|
||||
patch_size: (patch_size, patch_size),
|
||||
num_patches,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for PatchEmbed {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _c, h, w) = xs.dims4()?;
|
||||
let (patch_h, patch_w) = self.patch_size;
|
||||
if (h % patch_h) != 0 {
|
||||
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
|
||||
}
|
||||
if (w % patch_w) != 0 {
|
||||
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
|
||||
}
|
||||
let xs = self.proj.forward(xs)?;
|
||||
let (b, c, h, w) = xs.dims4()?;
|
||||
// flatten embeddings.
|
||||
xs.reshape((b, c, h * w))?.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EVA2VisionTransformer {
|
||||
patch_embed: PatchEmbed,
|
||||
cls_token: Tensor,
|
||||
pos_embed: Tensor,
|
||||
blocks: Vec<Block>,
|
||||
norm: LayerNorm,
|
||||
head: Linear,
|
||||
}
|
||||
|
||||
impl EVA2VisionTransformer {
|
||||
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
|
||||
let patch_embed =
|
||||
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
|
||||
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
|
||||
let pos_embed = vb.get((1, patch_embed.num_patches + 1, embed_dim), "pos_embed")?;
|
||||
let rot_pos_embed = vb.get((patch_embed.num_patches, 128), "rot_pos_embed")?;
|
||||
let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?;
|
||||
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
|
||||
let vb_b = vb.pp("blocks");
|
||||
let blocks = (0..depth)
|
||||
.map(|i| {
|
||||
Block::new(
|
||||
vb_b.pp(&i.to_string()),
|
||||
embed_dim,
|
||||
num_heads,
|
||||
&rot_pos_embed,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self {
|
||||
patch_embed,
|
||||
cls_token,
|
||||
pos_embed,
|
||||
blocks,
|
||||
norm,
|
||||
head,
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolate_pos_encoding(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
w: usize,
|
||||
h: usize,
|
||||
num_prefix_tokens: usize,
|
||||
) -> Result<Tensor> {
|
||||
let npatch = xs.dim(1)? - 1;
|
||||
let n = self.pos_embed.dim(1)? - 1;
|
||||
let sqrt_n = (n as f64).sqrt();
|
||||
if npatch == n && w == h {
|
||||
return Ok(self.pos_embed.clone());
|
||||
}
|
||||
// Interpolate only local tokens, i.e. those after the CLS token
|
||||
let prefix_tokens_pos_embed = self.pos_embed.i((0.., ..num_prefix_tokens, 0..))?.clone();
|
||||
let patch_pos_embed = &self.pos_embed.i((0.., num_prefix_tokens.., 0..))?;
|
||||
let dim = xs.dim(D::Minus1)?;
|
||||
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
|
||||
let patch_pos_embed = patch_pos_embed
|
||||
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)?;
|
||||
// This uses bicubic interpolation in the original implementation.
|
||||
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
|
||||
let el_count = patch_pos_embed.shape().elem_count();
|
||||
let patch_pos_embed =
|
||||
patch_pos_embed
|
||||
.transpose(1, 2)?
|
||||
.transpose(2, 3)?
|
||||
.reshape((1, el_count / dim, dim))?;
|
||||
Tensor::cat(&[&prefix_tokens_pos_embed, &patch_pos_embed], 1)
|
||||
}
|
||||
|
||||
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b, _nc, w, h) = xs.dims4()?;
|
||||
if (w != IMG_SIZE) || (h != IMG_SIZE) {
|
||||
panic!("Error: The input tensor should have the shape: Bx3x518x518.");
|
||||
}
|
||||
let xs = self.patch_embed.forward(xs)?;
|
||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||
let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h, 1)?)?;
|
||||
Ok(xs)
|
||||
}
|
||||
|
||||
fn get_intermediate_layers_not_chunked(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
) -> Result<Vec<Tensor>> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
let mut output = Vec::new();
|
||||
for (i, blk) in self.blocks.iter().enumerate() {
|
||||
xs = blk.forward(&xs)?;
|
||||
if blocks_to_take.contains(&i) {
|
||||
output.push(xs.clone());
|
||||
}
|
||||
}
|
||||
if output.len() != blocks_to_take.len() {
|
||||
candle::bail!(
|
||||
"only {} / {} blocks found",
|
||||
output.len(),
|
||||
blocks_to_take.len()
|
||||
);
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn get_intermediate_layers(
|
||||
&self,
|
||||
xs: &Tensor,
|
||||
blocks_to_take: &[usize],
|
||||
reshape: bool,
|
||||
return_class_token: bool,
|
||||
norm: bool,
|
||||
) -> Result<Tensor> {
|
||||
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
|
||||
let outputs = if norm {
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| self.norm.forward(out))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
let class_tokens = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 0)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
.map(|out| out.i((.., 1..)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let outputs = if reshape {
|
||||
let (b, _c, w, h) = xs.dims4()?;
|
||||
let patch_size = self.patch_embed.patch_size.0;
|
||||
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
|
||||
outputs
|
||||
.iter()
|
||||
.map(|out| {
|
||||
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
|
||||
.transpose(2, 3)?
|
||||
.transpose(1, 2)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
let outputs = if return_class_token {
|
||||
outputs
|
||||
.iter()
|
||||
.zip(class_tokens.iter())
|
||||
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
} else {
|
||||
outputs
|
||||
};
|
||||
|
||||
Tensor::stack(&outputs[..], 0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EVA2VisionTransformer {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||
for blk in self.blocks.iter() {
|
||||
xs = blk.forward(&xs)?
|
||||
}
|
||||
let xs_moy_local_tokens = xs.i((.., 1..))?.mean(1)?;
|
||||
let xs_norm = self.norm.forward(&xs_moy_local_tokens)?;
|
||||
self.head.forward(&xs_norm)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vit_base(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
|
||||
EVA2VisionTransformer::new(vb, 12, 768, 12)
|
||||
}
|
||||
|
||||
pub fn vit_large(vb: VarBuilder) -> Result<EVA2VisionTransformer> {
|
||||
EVA2VisionTransformer::new(vb, 24, 1024, 16)
|
||||
}
|
@ -388,6 +388,28 @@ pub struct Llama {
|
||||
}
|
||||
|
||||
impl Llama {
|
||||
// required by LLaVA
|
||||
pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.wte.forward(x)
|
||||
}
|
||||
// required by LLaVA
|
||||
pub fn forward_input_embed(
|
||||
&self,
|
||||
input_embed: &Tensor,
|
||||
index_pos: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
let (_, seq_len, _) = input_embed.dims3()?;
|
||||
let mut x = input_embed.clone();
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx, cache)?;
|
||||
}
|
||||
let x = self.ln_f.forward(&x)?;
|
||||
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
|
||||
let logits = self.lm_head.forward(&x)?;
|
||||
logits.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
|
267
candle-transformers/src/models/llava/config.rs
Normal file
267
candle-transformers/src/models/llava/config.rs
Normal file
@ -0,0 +1,267 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::models::{
|
||||
clip::{text_model::Activation, vision_model::ClipVisionConfig},
|
||||
llama::Config,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// original config from liuhaotian/llava
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LLaVAConfig {
|
||||
pub architectures: Vec<String>,
|
||||
pub bos_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
pub hidden_size: usize,
|
||||
#[serde(default = "default_image_aspect_ratio")]
|
||||
pub image_aspect_ratio: String,
|
||||
pub image_crop_resolution: usize,
|
||||
pub image_grid_pinpoints: Vec<(u32, u32)>,
|
||||
pub image_split_resolution: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub mm_hidden_size: usize,
|
||||
#[serde(default = "default_mm_patch_merge_type")]
|
||||
pub mm_patch_merge_type: String,
|
||||
pub mm_projector_type: String,
|
||||
pub mm_use_im_start_end: bool,
|
||||
pub mm_vision_select_feature: String,
|
||||
pub mm_vision_select_layer: isize,
|
||||
pub mm_vision_tower: Option<String>,
|
||||
pub model_type: String,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub pad_token_id: usize,
|
||||
pub rms_norm_eps: f32,
|
||||
pub rope_theta: f32,
|
||||
pub tokenizer_model_max_length: Option<usize>,
|
||||
pub torch_dtype: String,
|
||||
pub use_cache: bool,
|
||||
pub vocab_size: usize,
|
||||
#[serde(default = "default_image_token_index")]
|
||||
pub image_token_index: isize,
|
||||
#[serde(default = "default_hf")]
|
||||
pub hf: bool,
|
||||
}
|
||||
|
||||
fn default_hf() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn default_image_token_index() -> isize {
|
||||
-200
|
||||
}
|
||||
|
||||
fn default_mm_patch_merge_type() -> String {
|
||||
"flat".to_string()
|
||||
}
|
||||
|
||||
fn default_image_aspect_ratio() -> String {
|
||||
"square".to_string()
|
||||
}
|
||||
|
||||
impl LLaVAConfig {
|
||||
pub fn to_llama_config(&self) -> Config {
|
||||
Config {
|
||||
hidden_size: self.hidden_size,
|
||||
intermediate_size: self.intermediate_size,
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads,
|
||||
rms_norm_eps: self.rms_norm_eps as f64,
|
||||
rope_theta: self.rope_theta,
|
||||
bos_token_id: Some(self.bos_token_id as u32),
|
||||
eos_token_id: Some(self.eos_token_id as u32),
|
||||
use_flash_attn: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct HFLLaVATextConfig {
|
||||
pub architectures: Vec<String>,
|
||||
#[serde(default = "default_hidden_size")]
|
||||
pub hidden_size: usize,
|
||||
#[serde(default = "default_intermediate_size")]
|
||||
pub intermediate_size: usize,
|
||||
#[serde(default = "default_max_length")]
|
||||
pub max_length: usize,
|
||||
pub max_position_embeddings: usize,
|
||||
pub model_type: String,
|
||||
#[serde(default = "default_num_attention_heads")]
|
||||
pub num_attention_heads: usize,
|
||||
#[serde(default = "default_num_hidden_layers")]
|
||||
pub num_hidden_layers: usize,
|
||||
#[serde(default = "default_num_key_value_heads")]
|
||||
pub num_key_value_heads: usize,
|
||||
pub pad_token_id: usize,
|
||||
pub rms_norm_eps: f32,
|
||||
#[serde(default = "default_rope_theta")]
|
||||
pub rope_theta: f32,
|
||||
pub torch_dtype: String,
|
||||
#[serde(default = "default_use_cache")]
|
||||
pub use_cache: bool,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
fn default_num_hidden_layers() -> usize {
|
||||
32
|
||||
}
|
||||
|
||||
fn default_use_cache() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_hidden_size() -> usize {
|
||||
4096
|
||||
}
|
||||
|
||||
fn default_intermediate_size() -> usize {
|
||||
11008
|
||||
}
|
||||
|
||||
fn default_max_length() -> usize {
|
||||
4096
|
||||
}
|
||||
|
||||
fn default_num_attention_heads() -> usize {
|
||||
32
|
||||
}
|
||||
|
||||
fn default_num_key_value_heads() -> usize {
|
||||
32
|
||||
}
|
||||
|
||||
fn default_rope_theta() -> f32 {
|
||||
10000.0
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct HFLLaVAVisionConfig {
|
||||
pub hidden_size: usize,
|
||||
pub image_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub model_type: String,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub patch_size: usize,
|
||||
pub projection_dim: usize,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
// config from llava-v1.6-vicuna-7b-hf
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct HFLLaVAConfig {
|
||||
pub architectures: Vec<String>,
|
||||
pub ignore_index: isize,
|
||||
pub image_grid_pinpoints: Vec<(u32, u32)>,
|
||||
pub image_token_index: isize,
|
||||
pub model_type: String,
|
||||
pub projector_hidden_act: String,
|
||||
pub text_config: HFLLaVATextConfig,
|
||||
pub torch_dtype: String,
|
||||
pub use_image_newline_parameter: bool,
|
||||
pub vision_config: HFLLaVAVisionConfig,
|
||||
pub vision_feature_layer: isize,
|
||||
pub vision_feature_select_strategy: String,
|
||||
pub vocab_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct HFGenerationConfig {
|
||||
pub bos_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
#[serde(default = "default_max_length")]
|
||||
pub max_length: usize,
|
||||
pub pad_token_id: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct HFPreProcessorConfig {
|
||||
pub aspect_ratio_setting: String,
|
||||
pub crop_size: HashMap<String, usize>,
|
||||
pub do_center_crop: bool,
|
||||
pub do_convert_rgb: bool,
|
||||
pub do_normalize: bool,
|
||||
pub do_rescale: bool,
|
||||
pub do_resize: bool,
|
||||
pub image_mean: Vec<f32>,
|
||||
pub image_std: Vec<f32>,
|
||||
pub resample: u32,
|
||||
pub rescale_factor: f32,
|
||||
pub size: HashMap<String, f32>,
|
||||
}
|
||||
|
||||
impl HFLLaVAConfig {
|
||||
pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
|
||||
ClipVisionConfig {
|
||||
embed_dim: self.vision_config.hidden_size,
|
||||
activation: Activation::QuickGelu,
|
||||
intermediate_size: self.vision_config.intermediate_size,
|
||||
num_hidden_layers: self.vision_config.num_hidden_layers,
|
||||
num_attention_heads: self.vision_config.num_attention_heads,
|
||||
projection_dim: self.vision_config.projection_dim,
|
||||
num_channels: 3,
|
||||
image_size: self.vision_config.image_size,
|
||||
patch_size: self.vision_config.patch_size,
|
||||
}
|
||||
}
|
||||
fn map_projector_type(s: &str) -> String {
|
||||
if s == "gelu" {
|
||||
"mlp2x_gelu".to_string()
|
||||
} else {
|
||||
s.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn map_select_feature(s: &str) -> String {
|
||||
if s == "default" {
|
||||
"patch".to_string()
|
||||
} else {
|
||||
"cls_patch".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_llava_config(
|
||||
&self,
|
||||
generation_config: &HFGenerationConfig,
|
||||
preprocessor_config: &HFPreProcessorConfig,
|
||||
) -> LLaVAConfig {
|
||||
LLaVAConfig {
|
||||
hf: true,
|
||||
architectures: self.architectures.clone(),
|
||||
bos_token_id: generation_config.bos_token_id,
|
||||
eos_token_id: generation_config.eos_token_id,
|
||||
hidden_size: self.text_config.hidden_size,
|
||||
image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
|
||||
image_crop_resolution: 224,
|
||||
image_grid_pinpoints: self.image_grid_pinpoints.clone(),
|
||||
image_split_resolution: 224,
|
||||
intermediate_size: self.text_config.intermediate_size,
|
||||
max_position_embeddings: self.text_config.max_position_embeddings,
|
||||
mm_hidden_size: 1024,
|
||||
mm_patch_merge_type: "spatial_unpad".to_string(),
|
||||
mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
|
||||
mm_use_im_start_end: false,
|
||||
mm_vision_select_feature: Self::map_select_feature(
|
||||
&self.vision_feature_select_strategy,
|
||||
),
|
||||
mm_vision_select_layer: self.vision_feature_layer,
|
||||
mm_vision_tower: None,
|
||||
model_type: self.model_type.clone(),
|
||||
num_attention_heads: self.text_config.num_attention_heads,
|
||||
num_hidden_layers: self.text_config.num_hidden_layers,
|
||||
num_key_value_heads: self.text_config.num_key_value_heads,
|
||||
pad_token_id: self.text_config.pad_token_id,
|
||||
rms_norm_eps: self.text_config.rms_norm_eps,
|
||||
rope_theta: self.text_config.rope_theta,
|
||||
tokenizer_model_max_length: Some(4096),
|
||||
torch_dtype: self.torch_dtype.clone(),
|
||||
use_cache: self.text_config.use_cache,
|
||||
vocab_size: self.vocab_size,
|
||||
image_token_index: self.image_token_index,
|
||||
}
|
||||
}
|
||||
}
|
407
candle-transformers/src/models/llava/mod.rs
Normal file
407
candle-transformers/src/models/llava/mod.rs
Normal file
@ -0,0 +1,407 @@
|
||||
pub mod config;
|
||||
pub mod utils;
|
||||
|
||||
use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
|
||||
use crate::models::llama::{Cache, Llama};
|
||||
use crate::models::with_tracing::linear;
|
||||
|
||||
use candle::{bail, Device, IndexOp, Result, Tensor};
|
||||
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
|
||||
use fancy_regex::Regex;
|
||||
use utils::get_anyres_image_grid_shape;
|
||||
|
||||
use config::LLaVAConfig;
|
||||
|
||||
fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
|
||||
let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
|
||||
|
||||
if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
|
||||
if let Some(match_str) = captures.get(1) {
|
||||
let match_str = match_str.as_str();
|
||||
match_str.parse::<usize>().ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
|
||||
assert_eq!(tensor.dims().len(), 3);
|
||||
let (original_width, original_height) = *original_size;
|
||||
let tensor_dims = tensor.dims();
|
||||
let current_height = tensor_dims[1];
|
||||
let current_width = tensor_dims[2];
|
||||
let original_aspect_ratio = (original_width as f32) / (original_height as f32);
|
||||
let current_aspect_ratio = (current_width as f32) / (current_height as f32);
|
||||
if original_aspect_ratio > current_aspect_ratio {
|
||||
let scale_factor = (current_width as f32) / (original_width as f32);
|
||||
let new_height = (original_height as f32 * scale_factor).floor() as usize;
|
||||
let padding = (current_height - new_height) / 2;
|
||||
tensor.i((.., padding..current_width - padding, ..))
|
||||
} else {
|
||||
let scale_factor = (current_height as f32) / (original_height as f32);
|
||||
let new_width = (original_width as f32 * scale_factor).floor() as usize;
|
||||
let padding = (current_width - new_width) / 2;
|
||||
tensor.i((.., .., padding..current_width - padding))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdentityMap {}
|
||||
|
||||
impl Module for IdentityMap {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
Ok(x.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MMProjector {
|
||||
pub modules: Sequential,
|
||||
}
|
||||
|
||||
impl MMProjector {
|
||||
pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
|
||||
if config.mm_projector_type == "linear" {
|
||||
let vb_prefix = if config.hf {
|
||||
"multi_modal_projector.linear_1"
|
||||
} else {
|
||||
"model.mm_projector.0"
|
||||
};
|
||||
let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
|
||||
let modules = seq().add(linear);
|
||||
Ok(Self { modules })
|
||||
} else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
|
||||
let modules = if config.hf {
|
||||
let mut modules = seq().add(linear(
|
||||
config.mm_hidden_size,
|
||||
config.hidden_size,
|
||||
vb.pp("multi_modal_projector.linear_1"),
|
||||
)?);
|
||||
for i in 1..mlp_depth {
|
||||
modules = modules.add(Activation::Gelu).add(linear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
|
||||
)?);
|
||||
}
|
||||
modules
|
||||
} else {
|
||||
let mut modules = seq().add(linear(
|
||||
config.mm_hidden_size,
|
||||
config.hidden_size,
|
||||
vb.pp("model.mm_projector.0"),
|
||||
)?);
|
||||
for i in 1..mlp_depth {
|
||||
modules = modules.add(Activation::Gelu).add(linear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
vb.pp(format!("model.mm_projector.{}", i * 2)),
|
||||
)?);
|
||||
}
|
||||
modules
|
||||
};
|
||||
Ok(Self { modules })
|
||||
} else if config.mm_projector_type == "identity" {
|
||||
Ok(Self {
|
||||
modules: seq().add(IdentityMap {}),
|
||||
})
|
||||
} else {
|
||||
bail!(
|
||||
"Unsupported MM projector type: {}",
|
||||
config.mm_projector_type
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.modules.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ClipVisionTower {
|
||||
model: ClipVisionTransformer,
|
||||
select_layer: isize,
|
||||
select_feature_method: String,
|
||||
pub config: ClipVisionConfig,
|
||||
}
|
||||
|
||||
impl ClipVisionTower {
|
||||
pub fn new(
|
||||
vb: VarBuilder,
|
||||
select_layer: isize,
|
||||
select_feature_method: &str,
|
||||
config: &Option<ClipVisionConfig>,
|
||||
) -> Result<Self> {
|
||||
let config = if config.is_none() {
|
||||
ClipVisionConfig::clip_vit_large_patch14_336()
|
||||
} else {
|
||||
config.clone().unwrap()
|
||||
};
|
||||
let select_layer = match select_layer {
|
||||
-1 | -2 => select_layer,
|
||||
_ => bail!("Unsupported select layer: {}", select_layer),
|
||||
};
|
||||
let model = ClipVisionTransformer::new(vb, &config)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
select_layer,
|
||||
select_feature_method: select_feature_method.to_string(),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let result = self.model.output_hidden_states(x)?;
|
||||
let index = result.len() as isize + self.select_layer;
|
||||
let result = result[index as usize].clone();
|
||||
if self.select_feature_method == "cls_patch" {
|
||||
Ok(result)
|
||||
} else {
|
||||
result.i((.., 1..))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_patches_per_side(&self) -> usize {
|
||||
self.config.image_size / self.config.patch_size
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LLaVA {
|
||||
pub clip_vision_tower: ClipVisionTower,
|
||||
pub image_newline: Tensor,
|
||||
pub mm_projector: MMProjector,
|
||||
pub llama: Llama,
|
||||
config: LLaVAConfig,
|
||||
device: Device,
|
||||
}
|
||||
|
||||
impl LLaVA {
|
||||
pub fn load(
|
||||
vb: VarBuilder,
|
||||
config: &LLaVAConfig,
|
||||
clip_vision_config: Option<ClipVisionConfig>,
|
||||
) -> Result<Self> {
|
||||
let device = vb.device().clone();
|
||||
let llama_config = config.to_llama_config();
|
||||
let mm_projector = MMProjector::load(&vb, config)?;
|
||||
let (clip_vision_tower, image_newline, llama) = if config.hf {
|
||||
(
|
||||
ClipVisionTower::new(
|
||||
vb.pp("vision_tower.vision_model"),
|
||||
config.mm_vision_select_layer,
|
||||
&config.mm_vision_select_feature,
|
||||
&clip_vision_config,
|
||||
)?,
|
||||
vb.get(&[config.hidden_size], "image_newline")?
|
||||
.to_device(&device)?,
|
||||
Llama::load(vb.pp("language_model"), &llama_config)?,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
ClipVisionTower::new(
|
||||
vb.pp("model.vision_tower.vision_tower.vision_model"),
|
||||
config.mm_vision_select_layer,
|
||||
&config.mm_vision_select_feature,
|
||||
&clip_vision_config,
|
||||
)?,
|
||||
vb.get(&[config.hidden_size], "model.image_newline")?
|
||||
.to_device(&device)?,
|
||||
Llama::load(vb, &llama_config)?,
|
||||
)
|
||||
};
|
||||
Ok(Self {
|
||||
clip_vision_tower,
|
||||
image_newline,
|
||||
mm_projector,
|
||||
llama,
|
||||
config: (*config).clone(),
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let image_features = self.clip_vision_tower.forward(x)?;
|
||||
let image_features = self.mm_projector.forward(&image_features)?;
|
||||
Ok(image_features)
|
||||
}
|
||||
// currently only for single image, 4 dim tensor
|
||||
pub fn prepare_inputs_labels_for_multimodal(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
images: &[Tensor],
|
||||
image_sizes: &[(u32, u32)],
|
||||
) -> Result<Tensor> {
|
||||
//TODO: process of multiple images/ new line
|
||||
// 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576
|
||||
let concat_images = Tensor::cat(images, 0)?;
|
||||
let image_features_together = self.encode_images(&concat_images)?;
|
||||
let split_sizes = images
|
||||
.iter()
|
||||
.map(|x| x.shape().dims()[0])
|
||||
.collect::<Vec<usize>>();
|
||||
// can be replaced by split
|
||||
let mut index_pos = 0;
|
||||
let mut image_features = Vec::new();
|
||||
for split_size in split_sizes.iter() {
|
||||
image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
|
||||
index_pos += *split_size;
|
||||
}
|
||||
let mm_patch_merge_type = &self.config.mm_patch_merge_type;
|
||||
let image_aspect_ratio = &self.config.image_aspect_ratio;
|
||||
|
||||
let image_features = if mm_patch_merge_type == "flat" {
|
||||
image_features
|
||||
.iter()
|
||||
.map(|x| x.flatten(0, 1).unwrap())
|
||||
.collect::<Vec<Tensor>>()
|
||||
} else if mm_patch_merge_type.starts_with("spatial") {
|
||||
let mut new_image_features = Vec::new();
|
||||
for (image_idx, image_feature) in image_features.iter().enumerate() {
|
||||
let new_image_feature = if image_feature.dims()[0] > 1 {
|
||||
let base_image_feature = image_feature.get(0).unwrap();
|
||||
let patch_image_feature = image_feature.i(1..).unwrap();
|
||||
let height = self.clip_vision_tower.num_patches_per_side();
|
||||
let width = height;
|
||||
assert_eq!(height * width, base_image_feature.dims()[0]);
|
||||
let image_size = image_sizes[image_idx];
|
||||
let new_image_feature = if image_aspect_ratio == "anyres" {
|
||||
let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
|
||||
image_size,
|
||||
&self.config.image_grid_pinpoints,
|
||||
self.clip_vision_tower.config.image_size as u32,
|
||||
);
|
||||
patch_image_feature.reshape((
|
||||
num_patch_height as usize,
|
||||
num_patch_width as usize,
|
||||
height,
|
||||
width,
|
||||
(),
|
||||
))?
|
||||
} else {
|
||||
todo!("not implemented in original python LLaVA yet")
|
||||
};
|
||||
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
|
||||
let new_image_feature = new_image_feature
|
||||
.permute((4, 0, 2, 1, 3))?
|
||||
.flatten(1, 2)?
|
||||
.flatten(2, 3)?;
|
||||
let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
|
||||
let new_image_feature_dims = new_image_feature.dims();
|
||||
let image_new_line = self
|
||||
.image_newline
|
||||
.reshape((self.config.hidden_size, 1, 1))?
|
||||
.broadcast_as((
|
||||
new_image_feature_dims[0],
|
||||
new_image_feature_dims[1],
|
||||
1,
|
||||
))?;
|
||||
let new_image_feature =
|
||||
Tensor::cat(&[new_image_feature, image_new_line], 2)?;
|
||||
new_image_feature.flatten(1, 2)?.transpose(0, 1)?
|
||||
} else {
|
||||
new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
|
||||
};
|
||||
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
|
||||
} else {
|
||||
let new_image_feature = image_feature.get(0).unwrap();
|
||||
if mm_patch_merge_type.contains("unpad") {
|
||||
Tensor::cat(
|
||||
&[
|
||||
new_image_feature,
|
||||
self.image_newline.clone().unsqueeze(0).unwrap(),
|
||||
],
|
||||
0,
|
||||
)
|
||||
.unwrap()
|
||||
} else {
|
||||
new_image_feature
|
||||
}
|
||||
};
|
||||
new_image_features.push(new_image_feature);
|
||||
}
|
||||
new_image_features
|
||||
} else {
|
||||
bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
|
||||
};
|
||||
// can easily be replaced by nonzero if it is implemented in candle
|
||||
let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
|
||||
let mut image_indices = {
|
||||
let mut image_indices = vec![0_i64];
|
||||
image_indices.extend(
|
||||
input_ids_vec
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, x)| {
|
||||
if *x == self.config.image_token_index as i64 {
|
||||
Some(i as i64)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<i64>>(),
|
||||
);
|
||||
image_indices
|
||||
};
|
||||
if image_indices.len() == 1 {
|
||||
//no image, only [0],
|
||||
return self.llama.embed(input_ids);
|
||||
}
|
||||
|
||||
let input_ids_noim = input_ids_vec
|
||||
.iter()
|
||||
.filter_map(|x| {
|
||||
if *x != self.config.image_token_index as i64 {
|
||||
Some(*x)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<i64>>();
|
||||
let input_ids_noim_len = input_ids_noim.len();
|
||||
image_indices.push((input_ids_noim_len) as i64);
|
||||
let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
|
||||
let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
|
||||
// can be replace by split if it is implemented in candle
|
||||
let input_embed_no_ims = {
|
||||
let mut input_embeds = Vec::new();
|
||||
for i in 0..image_indices.len() - 1 {
|
||||
let start = (image_indices[i]) as usize;
|
||||
let end = image_indices[i + 1] as usize;
|
||||
input_embeds.push(cur_input_embeds.i((start..end, ..))?)
|
||||
}
|
||||
input_embeds
|
||||
};
|
||||
|
||||
let mut cur_new_input_embeds = Vec::new();
|
||||
for (i, image_feature) in image_features.iter().enumerate() {
|
||||
cur_new_input_embeds.push(input_embed_no_ims[i].clone());
|
||||
cur_new_input_embeds.push(image_feature.clone());
|
||||
}
|
||||
cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
|
||||
let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
|
||||
//trancate
|
||||
let new_input_embeds =
|
||||
if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
|
||||
let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
|
||||
if new_input_embeds_length > tokenizer_model_max_length {
|
||||
new_input_embeds.i((..tokenizer_model_max_length, ..))?
|
||||
} else {
|
||||
new_input_embeds
|
||||
}
|
||||
} else {
|
||||
new_input_embeds
|
||||
};
|
||||
new_input_embeds.unsqueeze(0)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
input_embeds: &Tensor,
|
||||
position_id: usize,
|
||||
cache: &mut Cache,
|
||||
) -> Result<Tensor> {
|
||||
self.llama
|
||||
.forward_input_embed(input_embeds, position_id, cache)
|
||||
}
|
||||
}
|
41
candle-transformers/src/models/llava/utils.rs
Normal file
41
candle-transformers/src/models/llava/utils.rs
Normal file
@ -0,0 +1,41 @@
|
||||
pub fn get_anyres_image_grid_shape(
|
||||
image_size: (u32, u32),
|
||||
grid_pinpoints: &[(u32, u32)],
|
||||
patch_size: u32,
|
||||
) -> (u32, u32) {
|
||||
let (width, height) = select_best_resolution(image_size, grid_pinpoints);
|
||||
(width / patch_size, height / patch_size)
|
||||
}
|
||||
|
||||
pub fn select_best_resolution(
|
||||
original_size: (u32, u32),
|
||||
possible_resolutions: &[(u32, u32)],
|
||||
) -> (u32, u32) {
|
||||
let (original_width, original_height) = original_size;
|
||||
let mut best_fit = (0, 0);
|
||||
let original_width_f = original_width as f32;
|
||||
let original_height_f = original_height as f32;
|
||||
let mut max_effective_resolution = 0_u32;
|
||||
let mut min_wasted_resolution = u32::MAX;
|
||||
for (width, height) in possible_resolutions {
|
||||
let width_f = *width as f32;
|
||||
let height_f = *height as f32;
|
||||
let scale = (width_f / original_width_f).min(height_f / original_height_f);
|
||||
let (downscaled_width, downscaled_height) = (
|
||||
(original_width_f * scale) as u32,
|
||||
(original_height_f * scale) as u32,
|
||||
);
|
||||
let effective_resolution =
|
||||
std::cmp::min((*width) * (*height), downscaled_width * downscaled_height);
|
||||
let wasted_resolution = (*width) * (*height) - effective_resolution;
|
||||
if effective_resolution > max_effective_resolution
|
||||
|| (effective_resolution == max_effective_resolution
|
||||
&& wasted_resolution < min_wasted_resolution)
|
||||
{
|
||||
best_fit = (*width, *height);
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution;
|
||||
}
|
||||
}
|
||||
best_fit
|
||||
}
|
800
candle-transformers/src/models/mobilenetv4.rs
Normal file
800
candle-transformers/src/models/mobilenetv4.rs
Normal file
@ -0,0 +1,800 @@
|
||||
//! MobileNet-v4 inference implementation based on timm.
|
||||
//!
|
||||
//! See "MobileNetV4 - Universal Models for the Mobile Ecosystem"
|
||||
//! https://arxiv.org/abs/2404.10518
|
||||
//!
|
||||
//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilenetv3.py
|
||||
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
batch_norm, conv2d_no_bias, linear, ops::softmax, Activation, Conv2dConfig, Func, VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum BlockType {
|
||||
Convolutional {
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
},
|
||||
UniversalBottleneck {
|
||||
out_channels: usize,
|
||||
start_kernel: usize,
|
||||
mid_kernel: usize,
|
||||
stride: usize,
|
||||
expand: usize,
|
||||
},
|
||||
EdgeResidual {
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
expand: usize,
|
||||
},
|
||||
Attention {
|
||||
out_channels: usize,
|
||||
heads: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
kv_dim: usize,
|
||||
kv_stride: usize,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
stem_dim: usize,
|
||||
activation: Activation,
|
||||
stages: [Vec<BlockType>; 5],
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
impl Config {
|
||||
pub fn small() -> Self {
|
||||
Self {
|
||||
stem_dim: 32,
|
||||
activation: Activation::Relu,
|
||||
stages: [
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 32, kernel: 3, stride: 2},
|
||||
BlockType::Convolutional { out_channels: 32, kernel: 1, stride: 1},
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 96, kernel: 3, stride: 2},
|
||||
BlockType::Convolutional { out_channels: 64, kernel: 1, stride: 1},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 3},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 3, mid_kernel: 3, stride: 2, expand: 6},
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 5, stride: 1, expand: 3},
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 128, start_kernel: 0, mid_kernel: 3, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn medium() -> Self {
|
||||
Self {
|
||||
stem_dim: 32,
|
||||
activation: Activation::Relu,
|
||||
stages: [
|
||||
vec![
|
||||
BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 2},
|
||||
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hybrid_medium() -> Self {
|
||||
Self {
|
||||
stem_dim: 32,
|
||||
activation: Activation::Relu,
|
||||
stages: [
|
||||
vec![
|
||||
BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 80, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 2},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 6},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 160, heads: 4, kernel: 3, stride: 1, kv_stride:2, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 160, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 6},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 2},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 0, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 256, heads: 4, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 256, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn large() -> Self {
|
||||
Self {
|
||||
stem_dim: 24,
|
||||
activation: Activation::Relu,
|
||||
stages: [
|
||||
vec![
|
||||
BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hybrid_large() -> Self {
|
||||
Self {
|
||||
stem_dim: 24,
|
||||
activation: Activation::Gelu,
|
||||
stages: [
|
||||
vec![
|
||||
BlockType::EdgeResidual { out_channels: 48, kernel: 3, stride: 2, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 96, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 192, heads: 8, kernel: 3, stride: 1, kv_stride:2, kv_dim: 48},
|
||||
BlockType::UniversalBottleneck { out_channels: 192, start_kernel: 3, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
|
||||
vec![
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 2, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 3, stride: 1, expand: 4},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 5, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
BlockType::Attention { out_channels: 512, heads: 8, kernel: 3, stride: 1, kv_stride:1, kv_dim: 64},
|
||||
BlockType::UniversalBottleneck { out_channels: 512, start_kernel: 5, mid_kernel: 0, stride: 1, expand: 4},
|
||||
],
|
||||
vec![
|
||||
BlockType::Convolutional { out_channels: 960, kernel: 1, stride: 1},
|
||||
],
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn depthwise_conv(
|
||||
channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding,
|
||||
groups: channels,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
|
||||
}
|
||||
|
||||
fn pointwise_conv(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
|
||||
let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
|
||||
|
||||
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
|
||||
}
|
||||
|
||||
//Universal block that uses two pointwise convolutions and all combinations of two depthwise convolutions.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn universal_inverted_bottleneck_block(
|
||||
cfg: &Config,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
expand: usize,
|
||||
start_kernel: usize,
|
||||
mid_kernel: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let act = cfg.activation;
|
||||
let skip_connection = (in_channels == out_channels) && (stride == 1);
|
||||
|
||||
let dw_start_stride = if mid_kernel > 0 { 1 } else { stride };
|
||||
let dw_start = depthwise_conv(
|
||||
in_channels,
|
||||
start_kernel,
|
||||
dw_start_stride,
|
||||
start_kernel / 2,
|
||||
vb.pp("dw_start"),
|
||||
);
|
||||
let pw_exp = pointwise_conv(in_channels, in_channels * expand, vb.pp("pw_exp"))?;
|
||||
let dw_mid = depthwise_conv(
|
||||
in_channels * expand,
|
||||
mid_kernel,
|
||||
stride,
|
||||
mid_kernel / 2,
|
||||
vb.pp("dw_mid"),
|
||||
);
|
||||
let pw_proj = pointwise_conv(in_channels * expand, out_channels, vb.pp("pw_proj"))?;
|
||||
|
||||
let gamma = vb.get(out_channels, "layer_scale.gamma");
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let residual = xs.clone();
|
||||
|
||||
let mut xs = xs.clone();
|
||||
|
||||
if let Ok(f) = &dw_start {
|
||||
xs = xs.apply(f)?;
|
||||
}
|
||||
|
||||
xs = xs.apply(&pw_exp)?.apply(&act)?;
|
||||
|
||||
if let Ok(f) = &dw_mid {
|
||||
xs = xs.apply(f)?.apply(&act)?;
|
||||
}
|
||||
|
||||
xs = xs.apply(&pw_proj)?;
|
||||
|
||||
if let Ok(g) = &gamma {
|
||||
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
|
||||
};
|
||||
|
||||
if skip_connection {
|
||||
xs = (xs + residual)?;
|
||||
}
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Convolutional block including norm and activation.
|
||||
fn conv_block(
|
||||
cfg: &Config,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding: kernel / 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let act = cfg.activation;
|
||||
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
|
||||
let conv = conv2d_no_bias(in_channels, out_channels, kernel, conv2d_cfg, vb.pp("conv"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)
|
||||
}))
|
||||
}
|
||||
|
||||
fn edge_residual_block(
|
||||
cfg: &Config,
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
expand: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv_exp_cfg = Conv2dConfig {
|
||||
stride,
|
||||
padding: kernel / 2,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let conv_pwl_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let act = cfg.activation;
|
||||
let mid_channels = in_channels * expand;
|
||||
let conv_exp = conv2d_no_bias(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
kernel,
|
||||
conv_exp_cfg,
|
||||
vb.pp("conv_exp"),
|
||||
)?;
|
||||
let bn1 = batch_norm(mid_channels, 1e-5, vb.pp("bn1"))?;
|
||||
|
||||
let conv_pwl = conv2d_no_bias(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
1,
|
||||
conv_pwl_cfg,
|
||||
vb.pp("conv_pwl"),
|
||||
)?;
|
||||
let bn2 = batch_norm(out_channels, 1e-5, vb.pp("bn2"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs
|
||||
.apply(&conv_exp)?
|
||||
.apply_t(&bn1, false)?
|
||||
.apply(&act)?
|
||||
.apply(&conv_pwl)?
|
||||
.apply_t(&bn2, false)?;
|
||||
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
fn reshape_kv(t: &Tensor) -> Result<Tensor> {
|
||||
let d = t.dims4()?;
|
||||
let t = t
|
||||
.reshape((d.0, d.1, ()))?
|
||||
.transpose(1, 2)?
|
||||
.unsqueeze(1)?
|
||||
.contiguous()?;
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
fn reshape_query(t: &Tensor, heads: usize, kv_dim: usize) -> Result<Tensor> {
|
||||
let d = t.dims4()?;
|
||||
|
||||
let t = t
|
||||
.reshape((d.0, heads, kv_dim, ()))?
|
||||
.transpose(D::Minus1, D::Minus2)?
|
||||
.contiguous()?;
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
fn reshape_output(t: &Tensor, heads: usize, h: usize, w: usize) -> Result<Tensor> {
|
||||
let d = t.dims4()?;
|
||||
let t = t.transpose(1, 2)?;
|
||||
let t = t
|
||||
.reshape((d.0, h, w, d.3 * heads))?
|
||||
.permute((0, 3, 1, 2))?
|
||||
.contiguous()?;
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
// Mobile multi-query attention
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn mqa_block(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
heads: usize,
|
||||
kernel: usize,
|
||||
stride: usize,
|
||||
kv_dim: usize,
|
||||
kv_stride: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let down_conv2d_cfg = Conv2dConfig {
|
||||
stride: kv_stride,
|
||||
padding: kernel / 2,
|
||||
groups: in_channels,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let proj_conv2d_cfg = Conv2dConfig {
|
||||
stride,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let skip_connection = (in_channels == out_channels) && (stride == 1);
|
||||
let gamma = vb.get(out_channels, "layer_scale.gamma");
|
||||
let norm = batch_norm(out_channels, 1e-5, vb.pp("norm"))?;
|
||||
let scale = (kv_dim as f64).powf(-0.5);
|
||||
|
||||
let vb = vb.pp("attn");
|
||||
|
||||
let query_proj = conv2d_no_bias(
|
||||
out_channels,
|
||||
kv_dim * heads,
|
||||
1,
|
||||
proj_conv2d_cfg,
|
||||
vb.pp("query.proj"),
|
||||
)?;
|
||||
|
||||
let key_down_conv = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
down_conv2d_cfg,
|
||||
vb.pp("key.down_conv"),
|
||||
);
|
||||
let key_norm = batch_norm(out_channels, 1e-5, vb.pp("key.norm"));
|
||||
|
||||
let key_proj = conv2d_no_bias(out_channels, kv_dim, 1, proj_conv2d_cfg, vb.pp("key.proj"))?;
|
||||
|
||||
let value_down_conv = conv2d_no_bias(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
down_conv2d_cfg,
|
||||
vb.pp("value.down_conv"),
|
||||
);
|
||||
|
||||
let value_norm = batch_norm(out_channels, 1e-5, vb.pp("value.norm"));
|
||||
let value_proj = conv2d_no_bias(
|
||||
out_channels,
|
||||
kv_dim,
|
||||
1,
|
||||
proj_conv2d_cfg,
|
||||
vb.pp("value.proj"),
|
||||
)?;
|
||||
|
||||
let output_proj = conv2d_no_bias(
|
||||
kv_dim * heads,
|
||||
out_channels,
|
||||
1,
|
||||
proj_conv2d_cfg,
|
||||
vb.pp("output.proj"),
|
||||
)?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let (_, _, h, w) = xs.dims4()?;
|
||||
|
||||
let residual = xs.clone();
|
||||
|
||||
let xs = xs.apply_t(&norm, false)?;
|
||||
|
||||
// Query
|
||||
let q = xs.apply(&query_proj)?;
|
||||
|
||||
let q = reshape_query(&q, heads, kv_dim)?;
|
||||
let q = (q * scale)?;
|
||||
|
||||
// Keys
|
||||
let mut k = xs.clone();
|
||||
|
||||
if let (Ok(kd), Ok(n)) = (&key_down_conv, &key_norm) {
|
||||
k = k.apply(kd)?.apply_t(n, false)?;
|
||||
}
|
||||
|
||||
let k = k.apply(&key_proj)?;
|
||||
|
||||
let k = reshape_kv(&k)?;
|
||||
|
||||
// Value
|
||||
let mut v = xs.clone();
|
||||
|
||||
if let (Ok(vd), Ok(n)) = (&value_down_conv, &value_norm) {
|
||||
v = v.apply(vd)?;
|
||||
v = v.apply_t(n, false)?;
|
||||
}
|
||||
|
||||
let v = v.apply(&value_proj)?;
|
||||
let v = reshape_kv(&v)?;
|
||||
|
||||
let attn = q.broadcast_matmul(&(k.transpose(D::Minus2, D::Minus1)?))?;
|
||||
let attn = softmax(&attn, D::Minus1)?;
|
||||
let o = attn.broadcast_matmul(&v)?;
|
||||
|
||||
let o = reshape_output(&o, heads, h, w)?;
|
||||
|
||||
let mut xs = o.apply(&output_proj)?;
|
||||
|
||||
// Layer scale
|
||||
|
||||
if let Ok(g) = &gamma {
|
||||
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
|
||||
};
|
||||
|
||||
if skip_connection {
|
||||
xs = (xs + residual)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Stem.
|
||||
fn mobilenetv4_stem(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
stride: 2,
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let act = cfg.activation;
|
||||
let out_channels = cfg.stem_dim;
|
||||
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn1"))?;
|
||||
let conv = conv2d_no_bias(3, out_channels, 3, conv2d_cfg, vb.pp("conv_stem"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&conv)?.apply_t(&bn, false)?.apply(&act)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// The blocks in all the 5 stages of the model.
|
||||
fn mobilenetv4_blocks(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let mut in_channels = cfg.stem_dim;
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
for stage in 0..5 {
|
||||
let nblocks = cfg.stages[stage].len();
|
||||
|
||||
for block in 0..nblocks {
|
||||
match cfg.stages[stage][block] {
|
||||
BlockType::Convolutional {
|
||||
out_channels,
|
||||
kernel,
|
||||
stride,
|
||||
} => {
|
||||
blocks.push(conv_block(
|
||||
cfg,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
stride,
|
||||
vb.pp(format!("{stage}.{block}")),
|
||||
)?);
|
||||
in_channels = out_channels;
|
||||
}
|
||||
|
||||
BlockType::EdgeResidual {
|
||||
out_channels,
|
||||
kernel,
|
||||
stride,
|
||||
expand,
|
||||
} => {
|
||||
blocks.push(edge_residual_block(
|
||||
cfg,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel,
|
||||
stride,
|
||||
expand,
|
||||
vb.pp(format!("{stage}.{block}")),
|
||||
)?);
|
||||
in_channels = out_channels;
|
||||
}
|
||||
|
||||
BlockType::UniversalBottleneck {
|
||||
out_channels,
|
||||
start_kernel,
|
||||
mid_kernel,
|
||||
stride,
|
||||
expand,
|
||||
} => {
|
||||
blocks.push(universal_inverted_bottleneck_block(
|
||||
cfg,
|
||||
in_channels,
|
||||
out_channels,
|
||||
expand,
|
||||
start_kernel,
|
||||
mid_kernel,
|
||||
stride,
|
||||
vb.pp(format!("{stage}.{block}")),
|
||||
)?);
|
||||
in_channels = out_channels;
|
||||
}
|
||||
|
||||
BlockType::Attention {
|
||||
out_channels,
|
||||
heads,
|
||||
kernel,
|
||||
stride,
|
||||
kv_dim,
|
||||
kv_stride,
|
||||
} => {
|
||||
blocks.push(mqa_block(
|
||||
in_channels,
|
||||
out_channels,
|
||||
heads,
|
||||
kernel,
|
||||
stride,
|
||||
kv_dim,
|
||||
kv_stride,
|
||||
vb.pp(format!("{stage}.{block}")),
|
||||
)?);
|
||||
in_channels = out_channels;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
for block in blocks.iter() {
|
||||
xs = xs.apply(block)?
|
||||
}
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Classification head.
|
||||
fn mobilenetv4_head(
|
||||
cfg: &Config,
|
||||
outputs: usize,
|
||||
nclasses: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let conv2d_cfg = Conv2dConfig {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let act = cfg.activation;
|
||||
let conv = conv2d_no_bias(960, outputs, 1, conv2d_cfg, vb.pp("conv_head"))?;
|
||||
let norm = batch_norm(outputs, 1e-5, vb.pp("norm_head"))?;
|
||||
let cls = linear(outputs, nclasses, vb.pp("classifier"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let mut xs = xs.clone();
|
||||
xs = xs.apply(&conv)?;
|
||||
xs = xs.apply_t(&norm, false)?.apply(&act)?;
|
||||
xs = xs.flatten_from(1)?;
|
||||
xs = xs.apply(&cls)?;
|
||||
Ok(xs)
|
||||
}))
|
||||
}
|
||||
|
||||
// Build a mobilenetv4 model for a given configuration.
|
||||
fn mobilenetv4_model(
|
||||
cfg: &Config,
|
||||
nclasses: Option<usize>,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Func<'static>> {
|
||||
let cls = match nclasses {
|
||||
None => None,
|
||||
Some(nclasses) => {
|
||||
let outputs = 1280;
|
||||
let head = mobilenetv4_head(cfg, outputs, nclasses, vb.clone())?;
|
||||
Some(head)
|
||||
}
|
||||
};
|
||||
|
||||
let stem = mobilenetv4_stem(cfg, vb.clone())?;
|
||||
|
||||
let blocks = mobilenetv4_blocks(cfg, vb.pp("blocks"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
let xs = xs.apply(&stem)?.apply(&blocks)?;
|
||||
let xs = xs.mean_keepdim(D::Minus1)?.mean_keepdim(D::Minus2)?;
|
||||
match &cls {
|
||||
None => Ok(xs),
|
||||
Some(cls) => xs.apply(cls),
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn mobilenetv4(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobilenetv4_model(cfg, Some(nclasses), vb)
|
||||
}
|
||||
|
||||
pub fn mobilenetv4_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
mobilenetv4_model(cfg, None, vb)
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
pub mod beit;
|
||||
pub mod bert;
|
||||
pub mod bigcode;
|
||||
pub mod blip;
|
||||
@ -6,23 +7,28 @@ pub mod chatglm;
|
||||
pub mod clip;
|
||||
pub mod convmixer;
|
||||
pub mod convnext;
|
||||
pub mod depth_anything_v2;
|
||||
pub mod dinov2;
|
||||
pub mod dinov2reg4;
|
||||
pub mod distilbert;
|
||||
pub mod efficientnet;
|
||||
pub mod efficientvit;
|
||||
pub mod encodec;
|
||||
pub mod eva2;
|
||||
pub mod falcon;
|
||||
pub mod gemma;
|
||||
pub mod jina_bert;
|
||||
pub mod llama;
|
||||
pub mod llama2_c;
|
||||
pub mod llama2_c_weights;
|
||||
pub mod llava;
|
||||
pub mod mamba;
|
||||
pub mod marian;
|
||||
pub mod metavoice;
|
||||
pub mod mistral;
|
||||
pub mod mixformer;
|
||||
pub mod mixtral;
|
||||
pub mod mobilenetv4;
|
||||
pub mod mobileone;
|
||||
pub mod moondream;
|
||||
pub mod mpt;
|
||||
|
@ -3,6 +3,7 @@ use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
|
||||
use candle::{IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub phi_config: PhiConfig,
|
||||
pub vision_config: VisionConfig,
|
||||
|
@ -56,24 +56,20 @@ impl RotaryEmbedding {
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((cfg.max_position_embeddings, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
dim,
|
||||
sin: emb.sin()?,
|
||||
cos: emb.cos()?,
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
|
||||
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
|
||||
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ impl LayerWeights {
|
||||
};
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?
|
||||
att.matmul(&v)?
|
||||
};
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.attn_output.forward(&y)?;
|
||||
@ -203,7 +203,6 @@ fn precomput_freqs_cis(
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
batch_size: usize,
|
||||
use_flash_attn: bool,
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
@ -252,12 +251,7 @@ impl ModelWeights {
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let kv_cache = KvCache::new(
|
||||
2,
|
||||
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||
DType::F32,
|
||||
device,
|
||||
)?;
|
||||
let kv_cache = KvCache::new(2, max_seq_len);
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
|
@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
|
||||
|
||||
impl ModelForCausalLM {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let base_model = Model::new(cfg, vb)?;
|
||||
let base_model = Model::new(cfg, vb.clone())?;
|
||||
let lm_head = if vb.contains_tensor("lm_head") {
|
||||
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||
} else {
|
||||
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
||||
};
|
||||
Ok(Self {
|
||||
base_model,
|
||||
lm_head,
|
||||
|
@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
|
||||
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
|
||||
let layers = convs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(_, &(in_c, out_c, name))| {
|
||||
.map(|&(in_c, out_c, name)| {
|
||||
candle_nn::conv2d(
|
||||
in_c,
|
||||
out_c,
|
||||
|
Reference in New Issue
Block a user