Compare commits


11 Commits

Author SHA1 Message Date
a3dd87f15e Adding Gemm and ArgMax operators to candle-onnx (#2231)
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-28 21:40:31 +02:00
242e006bbb Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-06-24 19:12:52 +02:00
6baa1d486b Fix a bug in the metal implementation of col2im1d. (#2284) 2024-06-22 23:21:20 +02:00
36cf54525d Fix the fast bf16 gemm cublas kernels. (#2274)
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
2024-06-18 23:46:58 +02:00
2b10aaa05d implement Slice op (#2260) 2024-06-12 07:15:32 +01:00
9f804af29d feat(ci): add trufflehog secrets detection (#2262)
* feat(ci): add trufflehog secrets detection

* fix(ci): remove unnecessary permissions
2024-06-10 21:03:54 +01:00
54ff971e35 Support for the new Qwen2 models. (#2257)
* Support for the new Qwen2 models.

* Add more models.
2024-06-07 10:51:50 +01:00
b9fac7ec00 implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode

The intent of this change is to allow eval of the current silero_vad.onnx (v4).
This onnx file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-06-06 22:36:23 +02:00
f65e90e7ef Bump the crate version. (#2248) 2024-06-05 15:49:15 +02:00
d39462856b Apply rustfmt. (#2247) 2024-06-04 22:54:09 +02:00
cb180eb23a ONNX: add ArgMin, ArgMax and LeakyRelu (#2246)
* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

* Added ArgMin operator implementation

* Added tests for ArgMin

* ArgMin now returns a tensor with i64

* Added tests from pytorch examples

* Added ArgMax operator implementation

* Added tests for ArgMax

* Added LeakyRelu implementation

* Added a test for LeakyRelu

* Typo fix

* Fix a weird automatic RustRover change

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
2024-06-04 22:49:02 +02:00
24 changed files with 1515 additions and 98 deletions

.github/workflows/trufflehog.yml (new file, +15)

@@ -0,0 +1,15 @@
on:
  push:

name: Secret Leaks

jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main


@@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
-version = "0.5.1"
+version = "0.6.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
-candle-nn = { path = "./candle-nn", version = "0.5.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
+candle-nn = { path = "./candle-nn", version = "0.6.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }


@@ -9,8 +9,10 @@ use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
-let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
+let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+    .to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
+candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
@@ -19,6 +21,7 @@ fn main() -> Result<()> {
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
+candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();


@@ -121,7 +121,8 @@ impl ReduceIndex {
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
let dst_to_set = dst.spare_capacity_mut();
-let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
+let dst_to_set =
+    unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];


@@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
-let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+let ys_to_set = unsafe {
+    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+};
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
@@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let rhs = &rhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
-let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+let ys_to_set = unsafe {
+    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+};
let mut dst_i = 0;
for src_i in (o_l1..o_l2).step_by(ob.len) {
f_vec(
@@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
let lhs = &lhs[ob.start..ob.start + ob.len];
let mut ys: Vec<T> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
-let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+let ys_to_set = unsafe {
+    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
+};
let mut dst_i = 0;
for src_i in (o_r1..o_r2).step_by(ob.len) {
f_vec(
@@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
crate::StridedBlocks::SingleBlock { start_offset, len } => {
let mut ys: Vec<U> = Vec::with_capacity(len);
let ys_to_set = ys.spare_capacity_mut();
-let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+let ys_to_set = unsafe {
+    std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
+};
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(len) };
@@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
-let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+let ys_to_set = unsafe {
+    std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
+};
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];


@@ -2035,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16(
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
let beta_f32: f32 = cfg.gemm.beta.to_f32();
-let alpha = f16::from_f32(alpha_f32);
-let beta = f16::from_f32(beta_f32);
// The type for alpha and beta depends on the computeType.
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
(
-sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
-(&alpha) as *const f16 as *const _,
-(&beta) as *const f16 as *const _,
+sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
+(&alpha_f32) as *const f32 as *const _,
+(&beta_f32) as *const f32 as *const _,
)
} else {
(
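
Note: the compute-type change above is the switch the earlier benchmark exercises via `set_gemm_reduced_precision_bf16`. A minimal sketch of driving the fixed bf16 path (assumes the `cuda` feature and a visible GPU; the shapes are illustrative):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let x = Tensor::randn(0f32, 1.0, (1024, 1024), &device)?.to_dtype(DType::BF16)?;
    // Opt in to the reduced-precision cublas path fixed in this commit.
    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
    let y = x.matmul(&x)?;
    println!("{:?}", y.dims());
    Ok(())
}
```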


@@ -848,7 +848,6 @@ impl BackendStorage for MetalStorage {
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
-let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "col2im1d_f32",
DType::U32 => "col2im1d_u32",
@@ -869,6 +868,12 @@
&kernel_l_mm,
)?
};
+// It is important for the command buffer to be obtained *after* the matmul
+// kernel has run, otherwise we might use a command buffer that has already been
+// committed, resulting in the following error:
+// _status < MTLCommandBufferStatusCommitted >
+// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
+let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_col2im1d(
&self.device.device,
&command_buffer,


@@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
+palette = { version = "0.7.6", optional = true }
+enterpolation = { version = "0.2.1", optional = true }
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
@@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
+depth_anything_v2 = ["palette", "enterpolation"]

[[example]]
name = "llama_multiprocess"
@@ -101,3 +104,7 @@ required-features = ["candle-datasets"]

[[example]]
name = "encodec"
required-features = ["encodec"]
+
+[[example]]
+name = "depth_anything_v2"
+required-features = ["depth_anything_v2"]


@@ -0,0 +1,13 @@
# candle-depth-anything-v2

[Depth Anything V2](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2) is a model for Monocular Depth Estimation (MDE, i.e. depth prediction from a single image) which
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.

This example first instantiates the DINOv2 model, then creates DepthAnythingV2 on top of it and runs the prediction.

## Running an example with color map and CUDA
```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```


@@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;
use candle::Tensor;
pub struct SpectralRColormap {
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}
impl SpectralRColormap {
pub(crate) fn new() -> Self {
// Define a colormap similar to 'Spectral_r' by specifying key colors.
// got the colors from ChatGPT-4o
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
]);
Self { gradient }
}
fn get_color(&self, value: f32) -> LinSrgb {
self.gradient.gen(value)
}
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
println!("Gray: {:?}", gray.dims());
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
let rgb_values: Vec<f32> = gray_values
.iter()
.map(|g| self.get_color(*g))
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
.collect();
let [.., height, width] = gray.dims() else {
candle::bail!("Not enough dims!")
};
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
color.permute((2, 0, 1))
}
}
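
Note: `SpectralRColormap::new` is `pub(crate)`, so the colormap is only reachable from inside this example crate; a small sketch of the intended call pattern (shapes and values are illustrative):

```rust
use candle::{Device, Tensor};

// Colorize a (1, H, W) depth map whose values are already scaled to [0, 1].
fn demo() -> candle::Result<Tensor> {
    let gray = Tensor::from_vec(vec![0.0f32, 0.25, 0.5, 1.0], (1, 2, 2), &Device::Cpu)?;
    let cmap = SpectralRColormap::new();
    cmap.gray2color(&gray) // -> (3, 2, 2): channels-first RGB
}
```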


@@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;
use crate::color_map::SpectralRColormap;
mod color_map;
// taken from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
const DINO_IMG_SIZE: usize = 518;
#[derive(Parser)]
struct Args {
#[arg(long)]
dinov2_model: Option<PathBuf>,
#[arg(long)]
depth_anything_v2_model: Option<PathBuf>,
#[arg(long)]
image: PathBuf,
#[arg(long)]
output_dir: Option<PathBuf>,
#[arg(long)]
cpu: bool,
#[arg(long)]
color_map: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let dinov2_model_file = match args.dinov2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-dino-v2".into());
api.get("dinov2_vits14.safetensors")?
}
Some(dinov2_model) => dinov2_model,
};
println!("Using file {:?}", dinov2_model_file);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
let dinov2 = dinov2::vit_small(vb)?;
println!("DinoV2 model built");
let depth_anything_model_file = match args.depth_anything_v2_model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
api.get("depth_anything_v2_vits.safetensors")?
}
Some(depth_anything_model) => depth_anything_model,
};
println!("Using file {:?}", depth_anything_model_file);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
println!("Loaded image {image:?}");
let depth = depth_anything.forward(&image)?;
println!("Got predictions {:?}", depth.shape());
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
let output_path = full_output_path(&args.image, &args.output_dir);
println!("Saving image to {}", output_path.to_string_lossy());
save_image(&output_image, output_path)?;
Ok(())
}
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
let input_file_name = image_path.file_name().unwrap();
let mut output_file_name = OsString::from("depth_");
output_file_name.push(input_file_name);
let mut output_path = match output_dir {
None => image_path.parent().unwrap().to_path_buf(),
Some(output_path) => output_path.clone(),
};
output_path.push(output_file_name);
output_path
}
fn load_and_prep_image(
image_path: &PathBuf,
device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
.unsqueeze(0)?
.to_dtype(F32)?
.to_device(&device)?;
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(&device)?
.broadcast_as(image.shape())?;
let image = (image / max_pixel_val)?;
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
Ok((original_height, original_width, image))
}
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
let mean_tensor =
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
let std_tensor =
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
image.sub(&mean_tensor)?.div(&std_tensor)
}
fn post_process_image(
image: &Tensor,
original_height: usize,
original_width: usize,
color_map: bool,
) -> Result<Tensor> {
let out = image.interpolate2d(original_height, original_width)?;
let out = scale_image(&out)?;
let out = if color_map {
let spectral_r = SpectralRColormap::new();
spectral_r.gray2color(&out)?
} else {
let rgb_slice = [&out, &out, &out];
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
};
let max_pixel_val = Tensor::try_from(255.0f32)?
.to_device(out.device())?
.broadcast_as(out.shape())?;
let out = (out * max_pixel_val)?;
out.to_dtype(U8)
}
fn scale_image(depth: &Tensor) -> Result<Tensor> {
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
let min_val_tensor = Tensor::try_from(*min_val)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
let depth = (depth - min_val_tensor)?;
let range = max_val - min_val;
let range_tensor = Tensor::try_from(range)?
.to_device(depth.device())?
.broadcast_as(depth.shape())?;
depth / range_tensor
}
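
Note: `scale_image` above is a plain min-max normalization, (x - min) / (max - min), applied before colorization; the same arithmetic on a toy tensor (values illustrative):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let depth = Tensor::from_vec(vec![2.0f32, 4.0, 6.0], (3,), &Device::Cpu)?;
    let min = Tensor::full(2.0f32, (3,), &Device::Cpu)?;
    let range = Tensor::full(4.0f32, (3,), &Device::Cpu)?; // max - min
    let scaled = ((depth - min)? / range)?;
    assert_eq!(scaled.to_vec1::<f32>()?, [0.0, 0.5, 1.0]);
    Ok(())
}
```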


@@ -144,6 +144,14 @@ enum WhichModel {
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
+#[value(name = "2-0.5b")]
+W2_0_5b,
+#[value(name = "2-1.5b")]
+W2_1_5b,
+#[value(name = "2-7b")]
+W2_7b,
+#[value(name = "2-72b")]
+W2_72b,
}

#[derive(Parser, Debug)]
@@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
-let size = match args.model {
-WhichModel::W0_5b => "0.5B",
-WhichModel::W1_8b => "1.8B",
-WhichModel::W4b => "4B",
-WhichModel::W7b => "7B",
-WhichModel::W14b => "14B",
-WhichModel::W72b => "72B",
-WhichModel::MoeA27b => "MoE-A2.7B",
+let (version, size) = match args.model {
+WhichModel::W2_0_5b => ("2", "0.5B"),
+WhichModel::W2_1_5b => ("2", "1.5B"),
+WhichModel::W2_7b => ("2", "7B"),
+WhichModel::W2_72b => ("2", "72B"),
+WhichModel::W0_5b => ("1.5", "0.5B"),
+WhichModel::W1_8b => ("1.5", "1.8B"),
+WhichModel::W4b => ("1.5", "4B"),
+WhichModel::W7b => ("1.5", "7B"),
+WhichModel::W14b => ("1.5", "14B"),
+WhichModel::W72b => ("1.5", "72B"),
+WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
};
-format!("Qwen/Qwen1.5-{size}")
+format!("Qwen/Qwen{version}-{size}")
}
};
let repo = api.repo(Repo::with_revision(
@@ -261,11 +273,15 @@
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
-WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
+WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
+vec![repo.get("model.safetensors")?]
+}
WhichModel::W4b
| WhichModel::W7b
+| WhichModel::W2_7b
| WhichModel::W14b
| WhichModel::W72b
+| WhichModel::W2_72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}


@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
-version = "0.5.1"
+version = "0.6.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]


@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
-version = "0.5.1"
+version = "0.6.0"
edition = "2021"
description = "CUDA kernels for Candle"


@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
-version = "0.5.1"
+version = "0.6.0"
edition = "2021"
description = "Metal kernels for Candle"


@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
+use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
n => candle::bail!("replication-pad with a size of {n} is not supported"),
}
}
#[derive(Clone, Debug)]
pub struct Identity;
impl Identity {
pub fn new() -> Identity {
Self
}
}
impl Default for Identity {
fn default() -> Self {
Self
}
}
impl Module for Identity {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
Ok(xs.clone())
}
}
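
Note: the new `Identity` module just returns its input, so it can stand in wherever a `Module` is expected (it fills the no-op slot in the DPT head's resize layers later in this compare); a quick usage sketch:

```rust
use candle::{Device, Module, Result, Tensor};
use candle_nn::ops::Identity;

fn main() -> Result<()> {
    let xs = Tensor::arange(0f32, 4f32, &Device::Cpu)?;
    let ys = Identity::new().forward(&xs)?; // returns a cheap clone of xs
    assert_eq!(xs.to_vec1::<f32>()?, ys.to_vec1::<f32>()?);
    Ok(())
}
```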


@@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
-version = "0.5.1"
+version = "0.6.0"
edition = "2021"
description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
-candle = { path = "../candle-core", package = "candle-core", version = "0.5.1" }
-candle-nn = { path = "../candle-nn", version = "0.5.1" }
+candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
+candle-nn = { path = "../candle-nn", version = "0.6.0" }
prost = "0.12.1"

[build-dependencies]


@@ -1,6 +1,6 @@
-use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
+use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
@@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
+DataType::Bool => Some(DType::U8),
_ => None,
}
}
@@ -56,6 +57,15 @@ impl Attr for str {
}
}
+impl Attr for GraphProto {
+const TYPE: AttributeType = AttributeType::Graph;
+fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+attr.g
+.as_ref()
+.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
+}
+}
+
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@@ -214,13 +224,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
-inputs: HashMap<String, Value>,
+mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
-let mut values = inputs;
+simple_eval_(graph, &mut inputs)
+}
+
+fn simple_eval_(
+graph: &onnx::GraphProto,
+values: &mut HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
@@ -958,6 +974,165 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
"If" => {
// protobuf encodes boolean false as 0 and true as 1
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
let attr_name = if cond != 0 {
"then_branch"
} else {
"else_branch"
};
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
if sub_graph.output.len() != node.output.len() {
bail!(
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
node.name,
sub_graph.output.len(),
node.output.len()
);
}
let branch_out = simple_eval_(sub_graph, values)?;
for (i, out) in node.output.iter().enumerate() {
values.insert(
out.clone(),
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
);
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
"Pad" => {
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
let data = get(&node.input[0])?;
let pads = get(&node.input[1])?;
if node.input.len() > 2 {
bail!(
"unsupported number of inputs {} for Pad node {:?}, expected 2",
node.input.len(),
node.name
);
}
if pads.rank() != 1 {
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
}
if pads.dim(0).unwrap() != 2 * data.rank() {
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
}
let pads = pads.to_vec1::<i64>()?;
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
match mode {
"reflect" => {
let mut out = data.clone();
for (i, &dim) in data.dims().iter().enumerate().rev() {
if pads_pre[i] == 0 && pads_post[i] == 0 {
continue;
}
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
}
let idx = if dim > 1 {
let cycle_len = dim * 2 - 1;
let skip = (pads_pre[i] as usize) % cycle_len;
let idx = zigzag(0, (dim - 1) as i64)
.skip(skip)
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
Tensor::from_iter(idx, out.device())?
} else {
Tensor::full(0i64, (dim,), out.device())?
};
out = out.index_select(&idx, i)?;
}
values.insert(node.output[0].clone(), out);
}
_ => bail!(
"unsupported 'mode' value {mode:?} for Pad node {:?}",
node.name
),
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
"Slice" => {
let data = get(&node.input[0])?;
let starts = get(&node.input[1])?;
let ends = get(&node.input[2])?;
let default_axes;
let default_steps;
let axes: &Tensor;
let steps: &Tensor;
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
// they are set to [1, ..., 1] of length len(starts)
match node.input.len() {
3 => {
let len = starts.dims()[0];
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
axes = default_axes.as_ref().unwrap();
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
4 => {
let len = starts.dims()[0];
axes = get(&node.input[3])?;
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
5 => {
steps = get(&node.input[4])?;
axes = get(&node.input[3])?;
}
_ => bail!(
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
node.input.len(),
node
),
}
let mut out = data.clone();
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
// All negative elements of axes are made non-negative by
// adding r to them, where r = rank(input).
let axis = if axis < 0 {
axis + data.rank() as i64
} else {
axis
} as usize;
let data_dim = data.dims()[axis] as i64;
let mut s = starts.get(i)?.to_scalar::<i64>()?;
let mut e = ends.get(i)?.to_scalar::<i64>()?;
// All negative values in starts[i] and ends[i] have
// dims[axes[i]] added to them, where dims are the
// dimensions of input.
if s < 0 {
s += data_dim;
}
if e < 0 {
e += data_dim;
}
let p = steps.get(i)?.to_scalar::<i64>()?;
// starts[i] is clamped into the range [0, dims[axes[i]]]
// for positive stepping and [0, dims[axes[i]]-1] for
// negative stepping.
// for positive stepping ends[axes[i]] is clamped to
// [0, dims[axes[i]]], while for negative stepping it is
// clamped to [-1, dims[axes[i]]-1].
if p >= 0 {
s = s.clamp(0, data_dim);
e = e.clamp(0, data_dim);
} else {
s = s.clamp(0, data_dim - 1);
e = e.clamp(-1, data_dim - 1);
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
@@ -1099,6 +1274,30 @@ pub fn simple_eval(
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
"Gemm" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let c = get(&node.input[2])?;
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
let a = if trans_a == 0 { a.clone() } else { a.t()? };
let b = if trans_b == 0 { b.clone() } else { b.t()? };
let output = a
.broadcast_mul(&alpha)?
.broadcast_matmul(&b)?
.broadcast_add(&c.broadcast_mul(&beta)?)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
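
Note: the `Gemm` arm above computes Y = alpha * A' * B' + beta * C, where A' and B' are optionally transposed via `transA`/`transB`. A plain-tensor sketch of the same arithmetic outside the ONNX machinery (names and values are illustrative):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?;
    let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &dev)?;
    let c = Tensor::new(&[[1f32, 1.], [1., 1.]], &dev)?;
    let (alpha, beta) = (0.5f64, 2.0f64);
    // Gemm with transA = transB = 0: y = alpha * (a @ b) + beta * c
    let y = ((a.matmul(&b)? * alpha)? + (c * beta)?)?;
    println!("{y}");
    Ok(())
}
```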


@@ -4,10 +4,12 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
-use candle_onnx::onnx;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
+use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
+use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
let manual_graph = create_model_proto_with_graph(None);
let inputs: HashMap<String, Tensor> = HashMap::new();
match candle_onnx::simple_eval(&manual_graph, inputs) {
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
Ok(_) => panic!("Expected an error due to undefined graph"),
}
Ok(())
}
@@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let first = z
-.to_vec1::<f64>()?
-.to_vec()
-.get(0)
-.expect("Failed to get first element")
-.clone();
+let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
Ok(())
}
@@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let first = z
-.to_vec1::<f64>()?
-.to_vec()
-.get(0)
-.expect("Failed to get first element")
-.clone();
+let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 0.0f64);
Ok(())
}
@@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let first = z
-.to_vec1::<f64>()?
-.to_vec()
-.get(0)
-.expect("Failed to get first element")
-.clone();
+let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
Ok(())
}
@@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let first = z
-.to_vec1::<f64>()?
-.to_vec()
-.get(0)
-.expect("Failed to get first element")
-.clone();
+let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 1.0f64);
Ok(())
}
@@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
assert_eq!(results[0][0], 0.36787944f32);
assert_eq!(results[0][1], 1.0f32);
-assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
+assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
Ok(())
}
@@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
),
_ => panic!("unsupported DType in test"),
};
-let tensor = onnx::TensorProto {
+let tensor = TensorProto {
data_type: data_type.into(),
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
raw_data: value,
@@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let results = z.to_vec2::<f32>()?;
-assert_eq!(
-results,
-vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
-);
+assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
Ok(())
}
@@ -1342,19 +1309,12 @@
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-let results = z.to_vec2::<f32>()?;
-assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
+assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
Ok(())
}
@@ -3150,3 +3110,300 @@ fn test_leakyrelu() -> Result<()> {
Ok(())
}
// "If"
#[test]
fn test_if() -> Result<()> {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let output_type_proto = Some(TypeProto {
value: Some(type_proto::Value::TensorType(type_proto::Tensor {
elem_type: DataType::Float.into(),
shape: Some(TensorShapeProto {
dim: vec![Dimension {
denotation: "".to_string(),
value: Some(dimension::Value::DimValue(5)),
}],
}),
})),
denotation: "".to_string(),
});
let then_branch = GraphProto {
output: vec![ValueInfoProto {
name: "then_out".to_string(),
r#type: output_type_proto.clone(),
doc_string: "".to_string(),
}],
node: vec![NodeProto {
op_type: "Constant".to_string(),
input: vec![],
output: vec!["then_out".to_string()],
attribute: vec![AttributeProto {
name: "value".to_string(),
r#type: AttributeType::Tensor.into(),
t: Some(TensorProto {
dims: vec![x.len() as i64],
float_data: x.clone(),
data_type: DataType::Float.into(),
..TensorProto::default()
}),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
};
let else_branch = GraphProto {
output: vec![ValueInfoProto {
name: "else_out".to_string(),
r#type: output_type_proto.clone(),
doc_string: "".to_string(),
}],
node: vec![NodeProto {
op_type: "Constant".to_string(),
input: vec![],
output: vec!["else_out".to_string()],
attribute: vec![AttributeProto {
name: "value".to_string(),
r#type: AttributeType::Tensor.into(),
t: Some(TensorProto {
dims: vec![y.len() as i64],
float_data: y.clone(),
data_type: DataType::Float.into(),
..TensorProto::default()
}),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "If".to_string(),
attribute: vec![
AttributeProto {
name: "then_branch".to_string(),
r#type: AttributeType::Graph.into(),
g: Some(then_branch),
..AttributeProto::default()
},
AttributeProto {
name: "else_branch".to_string(),
r#type: AttributeType::Graph.into(),
g: Some(else_branch),
..AttributeProto::default()
},
],
input: vec!["cond".to_string()],
output: vec!["res".to_string()],
..NodeProto::default()
}],
input: vec![],
output: vec![ValueInfoProto {
name: "res".to_string(),
doc_string: "".to_string(),
r#type: output_type_proto.clone(),
}],
..GraphProto::default()
}));
for cond in [1u8, 0] {
let inputs =
HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
let expected = if cond != 0 { &x } else { &y };
let Some(res) = outputs.get("res") else {
candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
};
assert_eq!(&res.to_vec1::<f32>()?, expected);
}
Ok(())
}
#[test]
fn test_pad() -> Result<()> {
let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
let mode = "reflect";
let expected = Tensor::from_vec(
vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
(3, 4),
&Device::Cpu,
)?;
let model = create_model_proto_with_graph(Some(GraphProto {
input: vec![
ValueInfoProto {
name: "data".to_string(),
..ValueInfoProto::default()
},
ValueInfoProto {
name: "pads".to_string(),
..ValueInfoProto::default()
},
],
output: vec![ValueInfoProto {
name: "output".to_string(),
..ValueInfoProto::default()
}],
node: vec![NodeProto {
op_type: "Pad".to_string(),
input: vec!["data".to_string(), "pads".to_string()],
output: vec!["output".to_string()],
attribute: vec![AttributeProto {
name: "mode".to_string(),
r#type: AttributeType::String.into(),
s: mode.as_bytes().to_vec(),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
}));
let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
let res = candle_onnx::simple_eval(&model, inputs)?;
let Some(actual) = res.get("output") else {
candle::bail!("outputs didn't contain expected key `output`: {res:?}");
};
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
Ok(())
}
#[test]
fn test_slice() -> Result<()> {
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec![
"data".to_string(),
"starts".to_string(),
"ends".to_string(),
"axes".to_string(),
"steps".to_string(),
],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends", "axes", "steps"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));
/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
steps = [1, 2]
result = [
[5, 7],
]
*/
let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
),
(
"axes".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"steps".to_string(),
Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![5i64, 7]]);
/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
starts = [0, 1]
ends = [-1, 1000]
result = [
[2, 3, 4],
]
*/
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));
let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![2i64, 3, 4]]);
Ok(())
}


@@ -0,0 +1,553 @@
use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
};
use crate::models::dinov2::DinoVisionTransformer;
pub struct DepthAnythingV2Config {
out_channel_sizes: [usize; 4],
in_channel_size: usize, // embed_dim in the Dino model
num_features: usize,
use_batch_norm: bool,
use_class_token: bool,
layer_ids_vits: Vec<usize>,
input_image_size: usize,
target_patch_size: usize,
}
impl DepthAnythingV2Config {
#[allow(clippy::too_many_arguments)]
pub fn new(
out_channel_sizes: [usize; 4],
in_channel_size: usize,
num_features: usize,
use_batch_norm: bool,
use_class_token: bool,
layer_ids_vits: Vec<usize>,
input_image_size: usize,
target_patch_size: usize,
) -> Self {
Self {
out_channel_sizes,
in_channel_size,
num_features,
use_batch_norm,
use_class_token,
layer_ids_vits,
input_image_size,
target_patch_size,
}
}
pub fn vit_small() -> Self {
Self {
out_channel_sizes: [48, 96, 192, 384],
in_channel_size: 384,
num_features: 64,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![2, 5, 8, 11],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_base() -> Self {
Self {
out_channel_sizes: [96, 192, 384, 768],
in_channel_size: 768,
num_features: 128,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![2, 5, 8, 11],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_large() -> Self {
Self {
out_channel_sizes: [256, 512, 1024, 1024],
in_channel_size: 1024,
num_features: 256,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![4, 11, 17, 23],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
pub fn vit_giant() -> Self {
Self {
out_channel_sizes: [1536, 1536, 1536, 1536],
in_channel_size: 1536,
num_features: 384,
use_batch_norm: false,
use_class_token: false,
layer_ids_vits: vec![9, 19, 29, 39],
input_image_size: 518,
target_patch_size: 518 / 14,
}
}
}
pub struct ResidualConvUnit {
activation: Activation,
conv1: Conv2d,
conv2: Conv2d,
batch_norm1: Option<BatchNorm>,
batch_norm2: Option<BatchNorm>,
}
impl ResidualConvUnit {
pub fn new(
conf: &DepthAnythingV2Config,
activation: Activation,
vb: VarBuilder,
) -> Result<Self> {
const KERNEL_SIZE: usize = 3;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let conv1 = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("conv1"),
)?;
let conv2 = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("conv2"),
)?;
let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
true => {
let batch_norm_cfg = BatchNormConfig {
eps: 1e-05,
remove_mean: false,
affine: true,
momentum: 0.1,
};
(
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
)
}
false => (None, None),
};
Ok(Self {
activation,
conv1,
conv2,
batch_norm1,
batch_norm2,
})
}
}
impl Module for ResidualConvUnit {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let out = self.activation.forward(xs)?;
let out = self.conv1.forward(&out)?;
let out = if let Some(batch_norm1) = &self.batch_norm1 {
batch_norm1.forward_train(&out)?
} else {
out
};
let out = self.activation.forward(&out)?;
let out = self.conv2.forward(&out)?;
let out = if let Some(batch_norm2) = &self.batch_norm2 {
batch_norm2.forward_train(&out)?
} else {
out
};
out + xs
}
}
pub struct FeatureFusionBlock {
res_conv_unit1: ResidualConvUnit,
res_conv_unit2: ResidualConvUnit,
output_conv: Conv2d,
target_patch_size: usize,
}
impl FeatureFusionBlock {
pub fn new(
conf: &DepthAnythingV2Config,
target_patch_size: usize,
activation: Activation,
vb: VarBuilder,
) -> Result<Self> {
const KERNEL_SIZE: usize = 1;
let conv_cfg = Conv2dConfig {
padding: 0,
stride: 1,
dilation: 1,
groups: 1,
};
let output_conv = conv2d(
conf.num_features,
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("out_conv"),
)?;
let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
Ok(Self {
res_conv_unit1,
res_conv_unit2,
output_conv,
target_patch_size,
})
}
}
impl Module for FeatureFusionBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let out = self.res_conv_unit2.forward(xs)?;
let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
self.output_conv.forward(&out)
}
}
pub struct Scratch {
layer1_rn: Conv2d,
layer2_rn: Conv2d,
layer3_rn: Conv2d,
layer4_rn: Conv2d,
refine_net1: FeatureFusionBlock,
refine_net2: FeatureFusionBlock,
refine_net3: FeatureFusionBlock,
refine_net4: FeatureFusionBlock,
output_conv1: Conv2d,
output_conv2: Sequential,
}
impl Scratch {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
const KERNEL_SIZE: usize = 3;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let layer1_rn = conv2d_no_bias(
conf.out_channel_sizes[0],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer1_rn"),
)?;
let layer2_rn = conv2d_no_bias(
conf.out_channel_sizes[1],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer2_rn"),
)?;
let layer3_rn = conv2d_no_bias(
conf.out_channel_sizes[2],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer3_rn"),
)?;
let layer4_rn = conv2d_no_bias(
conf.out_channel_sizes[3],
conf.num_features,
KERNEL_SIZE,
conv_cfg,
vb.pp("layer4_rn"),
)?;
let refine_net1 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 8,
Activation::Relu,
vb.pp("refinenet1"),
)?;
let refine_net2 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 4,
Activation::Relu,
vb.pp("refinenet2"),
)?;
let refine_net3 = FeatureFusionBlock::new(
conf,
conf.target_patch_size * 2,
Activation::Relu,
vb.pp("refinenet3"),
)?;
let refine_net4 = FeatureFusionBlock::new(
conf,
conf.target_patch_size,
Activation::Relu,
vb.pp("refinenet4"),
)?;
let conv_cfg = Conv2dConfig {
padding: 1,
stride: 1,
dilation: 1,
groups: 1,
};
let output_conv1 = conv2d(
conf.num_features,
conf.num_features / 2,
KERNEL_SIZE,
conv_cfg,
vb.pp("output_conv1"),
)?;
let output_conv2 = seq();
const HEAD_FEATURES_2: usize = 32;
const OUT_CHANNELS_2: usize = 1;
const KERNEL_SIZE_2: usize = 1;
let output_conv2 = output_conv2.add(conv2d(
conf.num_features / 2,
HEAD_FEATURES_2,
KERNEL_SIZE,
conv_cfg,
vb.pp("output_conv2").pp("0"),
)?);
let output_conv2 = output_conv2
.add(Activation::Relu)
.add(conv2d(
HEAD_FEATURES_2,
OUT_CHANNELS_2,
KERNEL_SIZE_2,
conv_cfg,
vb.pp("output_conv2").pp("2"),
)?)
.add(Activation::Relu);
Ok(Self {
layer1_rn,
layer2_rn,
layer3_rn,
layer4_rn,
refine_net1,
refine_net2,
refine_net3,
refine_net4,
output_conv1,
output_conv2,
})
}
}
const NUM_CHANNELS: usize = 4;
pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
}
impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
conf.in_channel_size,
*out_channel_size,
1,
Default::default(),
vb.pp("projects").pp(conv_index.to_string()),
)?);
}
let resize_layers: Vec<Box<dyn Module>> = vec![
Box::new(conv_transpose2d(
conf.out_channel_sizes[0],
conf.out_channel_sizes[0],
4,
ConvTranspose2dConfig {
padding: 0,
stride: 4,
dilation: 1,
output_padding: 0,
},
vb.pp("resize_layers").pp("0"),
)?),
Box::new(conv_transpose2d(
conf.out_channel_sizes[1],
conf.out_channel_sizes[1],
2,
ConvTranspose2dConfig {
padding: 0,
stride: 2,
dilation: 1,
output_padding: 0,
},
vb.pp("resize_layers").pp("1"),
)?),
Box::new(Identity::new()),
Box::new(conv2d(
conf.out_channel_sizes[3],
conf.out_channel_sizes[3],
3,
Conv2dConfig {
padding: 1,
stride: 2,
dilation: 1,
groups: 1,
},
vb.pp("resize_layers").pp("3"),
)?),
];
let readout_projections = if conf.use_class_token {
let mut rop = Vec::with_capacity(NUM_CHANNELS);
for rop_index in 0..NUM_CHANNELS {
rop.push(
seq()
.add(linear(
2 * conf.in_channel_size,
conf.in_channel_size,
vb.pp("readout_projects").pp(rop_index.to_string()),
)?)
.add(Activation::Gelu),
);
}
rop
} else {
vec![]
};
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
})
}
}
impl Module for DPTHead<'_> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
let to_cat = [x, readout];
let cat = Tensor::cat(&to_cat, Minus1)?;
self.readout_projections[i].forward(&cat)?
} else {
xs.get(i)?
};
let x_dims = x.dims();
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;
let x = self.resize_layers[i].forward(&x)?;
out.push(x);
}
let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
let res3_out = self
.scratch
.refine_net3
.res_conv_unit1
.forward(&layer_3_rn)?;
let res3_out = path4.add(&res3_out)?;
let path3 = self.scratch.refine_net3.forward(&res3_out)?;
let res2_out = self
.scratch
.refine_net2
.res_conv_unit1
.forward(&layer_2_rn)?;
let res2_out = path3.add(&res2_out)?;
let path2 = self.scratch.refine_net2.forward(&res2_out)?;
let res1_out = self
.scratch
.refine_net1
.res_conv_unit1
.forward(&layer_1_rn)?;
let res1_out = path2.add(&res1_out)?;
let path1 = self.scratch.refine_net1.forward(&res1_out)?;
let out = self.scratch.output_conv1.forward(&path1)?;
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
self.scratch.output_conv2.forward(&out)
}
}
pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
}
impl<'a> DepthAnythingV2<'a> {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
Ok(Self {
pretrained,
depth_head,
conf,
})
}
}
impl<'a> Module for DepthAnythingV2<'a> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,
&self.conf.layer_ids_vits,
false,
false,
true,
)?;
let depth = self.depth_head.forward(&features)?;
depth.relu()
}
}


@@ -258,6 +258,84 @@ impl DinoVisionTransformer {
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
fn get_intermediate_layers_not_chunked(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
) -> Result<Vec<Tensor>> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
let mut output = Vec::new();
for (i, blk) in self.blocks.iter().enumerate() {
xs = blk.forward(&xs)?;
if blocks_to_take.contains(&i) {
output.push(xs.clone());
}
}
if output.len() != blocks_to_take.len() {
candle::bail!(
"only {} / {} blocks found",
output.len(),
blocks_to_take.len()
);
}
Ok(output)
}
pub fn get_intermediate_layers(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
reshape: bool,
return_class_token: bool,
norm: bool,
) -> Result<Tensor> {
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
let outputs = if norm {
outputs
.iter()
.map(|out| self.norm.forward(out))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let class_tokens = outputs
.iter()
.map(|out| out.i((.., 0)))
.collect::<Result<Vec<_>>>()?;
let outputs = outputs
.iter()
.map(|out| out.i((.., 1..)))
.collect::<Result<Vec<_>>>()?;
let outputs = if reshape {
let (b, _c, w, h) = xs.dims4()?;
let patch_size = self.patch_embed.patch_size.0;
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
outputs
.iter()
.map(|out| {
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
.transpose(2, 3)?
.transpose(1, 2)
})
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let outputs = if return_class_token {
outputs
.iter()
.zip(class_tokens.iter())
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
Tensor::stack(&outputs[..], 0)
}
}

impl Module for DinoVisionTransformer {
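
Note: `get_intermediate_layers` is the hook `DepthAnythingV2::forward` (earlier in this compare) uses to pull feature maps out of the DINOv2 backbone; a minimal call sketch (the layer ids match the vit_small config):

```rust
use candle::{Result, Tensor};
use candle_transformers::models::dinov2::DinoVisionTransformer;

// Fetch the four normalized patch-token feature maps the DPT head consumes:
// reshape = false, return_class_token = false, norm = true.
fn dino_features(dino: &DinoVisionTransformer, image: &Tensor) -> Result<Tensor> {
    dino.get_intermediate_layers(image, &[2, 5, 8, 11], false, false, true)
}
```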


@@ -6,6 +6,7 @@ pub mod chatglm;
pub mod clip;
pub mod convmixer;
pub mod convnext;
+pub mod depth_anything_v2;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;


@@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
-let base_model = Model::new(cfg, vb)?;
+let base_model = Model::new(cfg, vb.clone())?;
+let lm_head = if vb.contains_tensor("lm_head") {
+linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+} else {
+Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
+};
Ok(Self {
base_model,
lm_head,
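
Note: the change above falls back to weight tying when the checkpoint ships no separate `lm_head` tensor, reusing the token-embedding matrix as the output projection. A standalone sketch of that fallback (the embedding table here is random, purely for illustration):

```rust
use candle::{Device, Module, Result, Tensor};
use candle_nn::Linear;

fn main() -> Result<()> {
    // Stand-in for base_model.embed_tokens.embeddings(): (vocab, hidden).
    let embeddings = Tensor::randn(0f32, 1.0, (16, 8), &Device::Cpu)?;
    let lm_head = Linear::from_weights(embeddings, None);
    let hidden = Tensor::randn(0f32, 1.0, (1, 8), &Device::Cpu)?;
    let logits = lm_head.forward(&hidden)?;
    assert_eq!(logits.dims(), &[1, 16]);
    Ok(())
}
```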


@@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
let layers = convs
.iter()
-.enumerate()
-.map(|(_, &(in_c, out_c, name))| {
+.map(|&(in_c, out_c, name)| {
candle_nn::conv2d(
in_c,
out_c,