mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
11 Commits
operators-
...
0.6.0
Author | SHA1 | Date | |
---|---|---|---|
a3dd87f15e | |||
242e006bbb | |||
6baa1d486b | |||
36cf54525d | |||
2b10aaa05d | |||
9f804af29d | |||
54ff971e35 | |||
b9fac7ec00 | |||
f65e90e7ef | |||
d39462856b | |||
cb180eb23a |
15
.github/workflows/trufflehog.yml
vendored
Normal file
15
.github/workflows/trufflehog.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
on:
|
||||||
|
push:
|
||||||
|
|
||||||
|
name: Secret Leaks
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
trufflehog:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
- name: Secret Scanning
|
||||||
|
uses: trufflesecurity/trufflehog@main
|
18
Cargo.toml
18
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.6.0" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
|
candle-datasets = { path = "./candle-datasets", version = "0.6.0" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.6.0" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
|
candle-kernels = { path = "./candle-kernels", version = "0.6.0" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.6.0" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.5.1" }
|
candle-nn = { path = "./candle-nn", version = "0.6.0" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
|
candle-onnx = { path = "./candle-onnx", version = "0.6.0" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
candle-transformers = { path = "./candle-transformers", version = "0.6.0" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
|
@ -9,8 +9,10 @@ use candle_core::{Device, Tensor};
|
|||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
||||||
let _x1 = x.matmul(&x)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
drop(_x1);
|
drop(_x1);
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
@ -19,6 +21,7 @@ fn main() -> Result<()> {
|
|||||||
println!("fp32: {:?}", start_time.elapsed());
|
println!("fp32: {:?}", start_time.elapsed());
|
||||||
drop(_x1);
|
drop(_x1);
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||||
|
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
||||||
let _x1 = x.matmul(&x)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
drop(_x1);
|
drop(_x1);
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
|
@ -121,7 +121,8 @@ impl ReduceIndex {
|
|||||||
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
||||||
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
||||||
let dst_to_set = dst.spare_capacity_mut();
|
let dst_to_set = dst.spare_capacity_mut();
|
||||||
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
|
let dst_to_set =
|
||||||
|
unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
|
||||||
match src_l.contiguous_offsets() {
|
match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => {
|
Some((o1, o2)) => {
|
||||||
let src = &src[o1..o2];
|
let src = &src[o1..o2];
|
||||||
|
@ -174,7 +174,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
|||||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
|
||||||
// SAFETY: values are all set by f_vec.
|
// SAFETY: values are all set by f_vec.
|
||||||
unsafe { ys.set_len(el_count) };
|
unsafe { ys.set_len(el_count) };
|
||||||
@ -185,7 +187,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
|||||||
let rhs = &rhs[ob.start..ob.start + ob.len];
|
let rhs = &rhs[ob.start..ob.start + ob.len];
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
let mut dst_i = 0;
|
let mut dst_i = 0;
|
||||||
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
for src_i in (o_l1..o_l2).step_by(ob.len) {
|
||||||
f_vec(
|
f_vec(
|
||||||
@ -224,7 +228,9 @@ pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [
|
|||||||
let lhs = &lhs[ob.start..ob.start + ob.len];
|
let lhs = &lhs[ob.start..ob.start + ob.len];
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
|
||||||
|
};
|
||||||
let mut dst_i = 0;
|
let mut dst_i = 0;
|
||||||
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
for src_i in (o_r1..o_r2).step_by(ob.len) {
|
||||||
f_vec(
|
f_vec(
|
||||||
@ -311,7 +317,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
|
|||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(len);
|
let mut ys: Vec<U> = Vec::with_capacity(len);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||||
|
};
|
||||||
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
f_vec(&vs[start_offset..start_offset + len], ys_to_set);
|
||||||
// SAFETY: values are all set by f_vec.
|
// SAFETY: values are all set by f_vec.
|
||||||
unsafe { ys.set_len(len) };
|
unsafe { ys.set_len(len) };
|
||||||
@ -333,7 +341,9 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
|
|||||||
} else {
|
} else {
|
||||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
let ys_to_set = unsafe {
|
||||||
|
std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
|
||||||
|
};
|
||||||
let mut dst_index = 0;
|
let mut dst_index = 0;
|
||||||
for src_index in block_start_index {
|
for src_index in block_start_index {
|
||||||
let vs = &vs[src_index..src_index + block_len];
|
let vs = &vs[src_index..src_index + block_len];
|
||||||
|
@ -2035,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16(
|
|||||||
|
|
||||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||||
let alpha = f16::from_f32(alpha_f32);
|
|
||||||
let beta = f16::from_f32(beta_f32);
|
|
||||||
// The type for alpha and beta depends on the computeType.
|
// The type for alpha and beta depends on the computeType.
|
||||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
||||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
||||||
(
|
(
|
||||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF,
|
||||||
(&alpha) as *const f16 as *const _,
|
(&alpha_f32) as *const f32 as *const _,
|
||||||
(&beta) as *const f16 as *const _,
|
(&beta_f32) as *const f32 as *const _,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
(
|
(
|
||||||
|
@ -848,7 +848,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
.device
|
.device
|
||||||
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer()?;
|
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "col2im1d_f32",
|
DType::F32 => "col2im1d_f32",
|
||||||
DType::U32 => "col2im1d_u32",
|
DType::U32 => "col2im1d_u32",
|
||||||
@ -869,6 +868,12 @@ impl BackendStorage for MetalStorage {
|
|||||||
&kernel_l_mm,
|
&kernel_l_mm,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
// It is important for the command buffer to be obtained *after* the matmul
|
||||||
|
// kernel has run, otherwise we might use a command-buffer that has been commited
|
||||||
|
// already resulting in the following error.
|
||||||
|
// _status < MTLCommandBufferStatusCommitted >
|
||||||
|
// -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:]
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
candle_metal_kernels::call_col2im1d(
|
candle_metal_kernels::call_col2im1d(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
|
@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
|
|||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
palette = { version = "0.7.6", optional = true }
|
||||||
|
enterpolation = { version = "0.2.1", optional = true}
|
||||||
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
rubato = { version = "0.15.0", optional = true }
|
rubato = { version = "0.15.0", optional = true }
|
||||||
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
|
|||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
microphone = ["cpal"]
|
microphone = ["cpal"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "depth_anything_v2"
|
||||||
|
required-features = ["depth_anything_v2"]
|
||||||
|
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-dinov2
|
||||||
|
|
||||||
|
[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
|
||||||
|
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
|
||||||
|
|
||||||
|
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
|
||||||
|
|
||||||
|
## Running an example with color map and CUDA
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
```
|
||||||
|
|
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use enterpolation::linear::ConstEquidistantLinear;
|
||||||
|
use enterpolation::Generator;
|
||||||
|
use palette::LinSrgb;
|
||||||
|
|
||||||
|
use candle::Tensor;
|
||||||
|
|
||||||
|
pub struct SpectralRColormap {
|
||||||
|
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpectralRColormap {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
// Define a colormap similar to 'Spectral_r' by specifying key colors.
|
||||||
|
// got the colors from ChatGPT-4o
|
||||||
|
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
|
||||||
|
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
|
||||||
|
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
|
||||||
|
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
|
||||||
|
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
|
||||||
|
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
|
||||||
|
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
|
||||||
|
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
|
||||||
|
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
|
||||||
|
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
|
||||||
|
]);
|
||||||
|
Self { gradient }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_color(&self, value: f32) -> LinSrgb {
|
||||||
|
self.gradient.gen(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
println!("Gray: {:?}", gray.dims());
|
||||||
|
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
|
||||||
|
let rgb_values: Vec<f32> = gray_values
|
||||||
|
.iter()
|
||||||
|
.map(|g| self.get_color(*g))
|
||||||
|
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let [.., height, width] = gray.dims() else {
|
||||||
|
candle::bail!("Not enough dims!")
|
||||||
|
};
|
||||||
|
|
||||||
|
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
|
||||||
|
|
||||||
|
color.permute((2, 0, 1))
|
||||||
|
}
|
||||||
|
}
|
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
//! Depth Anything V2
|
||||||
|
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use std::ffi::OsString;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::DType::{F32, U8};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor};
|
||||||
|
use candle_examples::{load_image, load_image_and_resize, save_image};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
|
||||||
|
use candle_transformers::models::dinov2;
|
||||||
|
|
||||||
|
use crate::color_map::SpectralRColormap;
|
||||||
|
|
||||||
|
mod color_map;
|
||||||
|
|
||||||
|
// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
|
||||||
|
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
|
||||||
|
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
|
||||||
|
|
||||||
|
const DINO_IMG_SIZE: usize = 518;
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
#[arg(long)]
|
||||||
|
dinov2_model: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
depth_anything_v2_model: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: PathBuf,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
output_dir: Option<PathBuf>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
color_map: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let dinov2_model_file = match args.dinov2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("lmz/candle-dino-v2".into());
|
||||||
|
api.get("dinov2_vits14.safetensors")?
|
||||||
|
}
|
||||||
|
Some(dinov2_model) => dinov2_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", dinov2_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
|
||||||
|
let dinov2 = dinov2::vit_small(vb)?;
|
||||||
|
println!("DinoV2 model built");
|
||||||
|
|
||||||
|
let depth_anything_model_file = match args.depth_anything_v2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
|
||||||
|
api.get("depth_anything_v2_vits.safetensors")?
|
||||||
|
}
|
||||||
|
Some(depth_anything_model) => depth_anything_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", depth_anything_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = DepthAnythingV2Config::vit_small();
|
||||||
|
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
|
||||||
|
|
||||||
|
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
|
||||||
|
|
||||||
|
println!("Loaded image {image:?}");
|
||||||
|
|
||||||
|
let depth = depth_anything.forward(&image)?;
|
||||||
|
|
||||||
|
println!("Got predictions {:?}", depth.shape());
|
||||||
|
|
||||||
|
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
|
||||||
|
|
||||||
|
let output_path = full_output_path(&args.image, &args.output_dir);
|
||||||
|
println!("Saving image to {}", output_path.to_string_lossy());
|
||||||
|
save_image(&output_image, output_path)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
|
||||||
|
let input_file_name = image_path.file_name().unwrap();
|
||||||
|
let mut output_file_name = OsString::from("depth_");
|
||||||
|
output_file_name.push(input_file_name);
|
||||||
|
let mut output_path = match output_dir {
|
||||||
|
None => image_path.parent().unwrap().to_path_buf(),
|
||||||
|
Some(output_path) => output_path.clone(),
|
||||||
|
};
|
||||||
|
output_path.push(output_file_name);
|
||||||
|
|
||||||
|
output_path
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_and_prep_image(
|
||||||
|
image_path: &PathBuf,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(usize, usize, Tensor)> {
|
||||||
|
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
|
||||||
|
|
||||||
|
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
|
||||||
|
.unsqueeze(0)?
|
||||||
|
.to_dtype(F32)?
|
||||||
|
.to_device(&device)?;
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.broadcast_as(image.shape())?;
|
||||||
|
let image = (image / max_pixel_val)?;
|
||||||
|
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
|
||||||
|
|
||||||
|
Ok((original_height, original_width, image))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
|
||||||
|
let mean_tensor =
|
||||||
|
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
let std_tensor =
|
||||||
|
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
image.sub(&mean_tensor)?.div(&std_tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_process_image(
|
||||||
|
image: &Tensor,
|
||||||
|
original_height: usize,
|
||||||
|
original_width: usize,
|
||||||
|
color_map: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let out = image.interpolate2d(original_height, original_width)?;
|
||||||
|
let out = scale_image(&out)?;
|
||||||
|
|
||||||
|
let out = if color_map {
|
||||||
|
let spectral_r = SpectralRColormap::new();
|
||||||
|
spectral_r.gray2color(&out)?
|
||||||
|
} else {
|
||||||
|
let rgb_slice = [&out, &out, &out];
|
||||||
|
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(out.device())?
|
||||||
|
.broadcast_as(out.shape())?;
|
||||||
|
let out = (out * max_pixel_val)?;
|
||||||
|
|
||||||
|
out.to_dtype(U8)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scale_image(depth: &Tensor) -> Result<Tensor> {
|
||||||
|
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
|
||||||
|
|
||||||
|
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
|
||||||
|
let min_val_tensor = Tensor::try_from(*min_val)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
let depth = (depth - min_val_tensor)?;
|
||||||
|
|
||||||
|
let range = max_val - min_val;
|
||||||
|
let range_tensor = Tensor::try_from(range)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
|
||||||
|
depth / range_tensor
|
||||||
|
}
|
@ -144,6 +144,14 @@ enum WhichModel {
|
|||||||
W72b,
|
W72b,
|
||||||
#[value(name = "moe-a2.7b")]
|
#[value(name = "moe-a2.7b")]
|
||||||
MoeA27b,
|
MoeA27b,
|
||||||
|
#[value(name = "2-0.5b")]
|
||||||
|
W2_0_5b,
|
||||||
|
#[value(name = "2-1.5b")]
|
||||||
|
W2_1_5b,
|
||||||
|
#[value(name = "2-7b")]
|
||||||
|
W2_7b,
|
||||||
|
#[value(name = "2-72b")]
|
||||||
|
W2_72b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -234,16 +242,20 @@ fn main() -> Result<()> {
|
|||||||
let model_id = match args.model_id {
|
let model_id = match args.model_id {
|
||||||
Some(model_id) => model_id,
|
Some(model_id) => model_id,
|
||||||
None => {
|
None => {
|
||||||
let size = match args.model {
|
let (version, size) = match args.model {
|
||||||
WhichModel::W0_5b => "0.5B",
|
WhichModel::W2_0_5b => ("2", "0.5B"),
|
||||||
WhichModel::W1_8b => "1.8B",
|
WhichModel::W2_1_5b => ("2", "1.5B"),
|
||||||
WhichModel::W4b => "4B",
|
WhichModel::W2_7b => ("2", "7B"),
|
||||||
WhichModel::W7b => "7B",
|
WhichModel::W2_72b => ("2", "72B"),
|
||||||
WhichModel::W14b => "14B",
|
WhichModel::W0_5b => ("1.5", "0.5B"),
|
||||||
WhichModel::W72b => "72B",
|
WhichModel::W1_8b => ("1.5", "1.8B"),
|
||||||
WhichModel::MoeA27b => "MoE-A2.7B",
|
WhichModel::W4b => ("1.5", "4B"),
|
||||||
|
WhichModel::W7b => ("1.5", "7B"),
|
||||||
|
WhichModel::W14b => ("1.5", "14B"),
|
||||||
|
WhichModel::W72b => ("1.5", "72B"),
|
||||||
|
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
||||||
};
|
};
|
||||||
format!("Qwen/Qwen1.5-{size}")
|
format!("Qwen/Qwen{version}-{size}")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
@ -261,11 +273,15 @@ fn main() -> Result<()> {
|
|||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
|
||||||
|
vec![repo.get("model.safetensors")?]
|
||||||
|
}
|
||||||
WhichModel::W4b
|
WhichModel::W4b
|
||||||
| WhichModel::W7b
|
| WhichModel::W7b
|
||||||
|
| WhichModel::W2_7b
|
||||||
| WhichModel::W14b
|
| WhichModel::W14b
|
||||||
| WhichModel::W72b
|
| WhichModel::W72b
|
||||||
|
| WhichModel::W2_72b
|
||||||
| WhichModel::MoeA27b => {
|
| WhichModel::MoeA27b => {
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.6.0" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
|
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||||
@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
|
|||||||
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Identity;
|
||||||
|
|
||||||
|
impl Identity {
|
||||||
|
pub fn new() -> Identity {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Identity {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Identity {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
Ok(xs.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.5.1"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.5.1" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.6.0" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.5.1" }
|
candle-nn = { path = "../candle-nn", version = "0.6.0" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use crate::onnx;
|
|
||||||
use crate::onnx::attribute_proto::AttributeType;
|
use crate::onnx::attribute_proto::AttributeType;
|
||||||
use crate::onnx::tensor_proto::DataType;
|
use crate::onnx::tensor_proto::DataType;
|
||||||
|
use crate::onnx::{self, GraphProto};
|
||||||
use candle::{bail, DType, Device, Result, Tensor};
|
use candle::{bail, DType, Device, Result, Tensor};
|
||||||
use std::{collections::HashMap, usize};
|
use std::{collections::HashMap, usize};
|
||||||
|
|
||||||
@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
|
|||||||
DataType::Float16 => Some(DType::F16),
|
DataType::Float16 => Some(DType::F16),
|
||||||
DataType::Float => Some(DType::F32),
|
DataType::Float => Some(DType::F32),
|
||||||
DataType::Double => Some(DType::F64),
|
DataType::Double => Some(DType::F64),
|
||||||
|
DataType::Bool => Some(DType::U8),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -56,6 +57,15 @@ impl Attr for str {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Attr for GraphProto {
|
||||||
|
const TYPE: AttributeType = AttributeType::Graph;
|
||||||
|
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||||
|
attr.g
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl AttrOwned for Tensor {
|
impl AttrOwned for Tensor {
|
||||||
const TYPE: AttributeType = AttributeType::Tensor;
|
const TYPE: AttributeType = AttributeType::Tensor;
|
||||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||||
@ -214,13 +224,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
|||||||
// anymore.
|
// anymore.
|
||||||
pub fn simple_eval(
|
pub fn simple_eval(
|
||||||
model: &onnx::ModelProto,
|
model: &onnx::ModelProto,
|
||||||
inputs: HashMap<String, Value>,
|
mut inputs: HashMap<String, Value>,
|
||||||
) -> Result<HashMap<String, Value>> {
|
) -> Result<HashMap<String, Value>> {
|
||||||
let graph = match &model.graph {
|
let graph = match &model.graph {
|
||||||
None => bail!("no graph defined in proto"),
|
None => bail!("no graph defined in proto"),
|
||||||
Some(graph) => graph,
|
Some(graph) => graph,
|
||||||
};
|
};
|
||||||
let mut values = inputs;
|
simple_eval_(graph, &mut inputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn simple_eval_(
|
||||||
|
graph: &onnx::GraphProto,
|
||||||
|
values: &mut HashMap<String, Value>,
|
||||||
|
) -> Result<HashMap<String, Value>> {
|
||||||
for t in graph.initializer.iter() {
|
for t in graph.initializer.iter() {
|
||||||
let tensor = get_tensor(t, t.name.as_str())?;
|
let tensor = get_tensor(t, t.name.as_str())?;
|
||||||
values.insert(t.name.to_string(), tensor);
|
values.insert(t.name.to_string(), tensor);
|
||||||
@ -958,6 +974,165 @@ pub fn simple_eval(
|
|||||||
let input = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
values.insert(node.output[0].clone(), input.clone());
|
values.insert(node.output[0].clone(), input.clone());
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
|
||||||
|
"If" => {
|
||||||
|
// protobuf encodes boolean false as 0 and true as 1
|
||||||
|
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
|
||||||
|
let attr_name = if cond != 0 {
|
||||||
|
"then_branch"
|
||||||
|
} else {
|
||||||
|
"else_branch"
|
||||||
|
};
|
||||||
|
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
|
||||||
|
if sub_graph.output.len() != node.output.len() {
|
||||||
|
bail!(
|
||||||
|
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
|
||||||
|
node.name,
|
||||||
|
sub_graph.output.len(),
|
||||||
|
node.output.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let branch_out = simple_eval_(sub_graph, values)?;
|
||||||
|
for (i, out) in node.output.iter().enumerate() {
|
||||||
|
values.insert(
|
||||||
|
out.clone(),
|
||||||
|
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
|
||||||
|
"Pad" => {
|
||||||
|
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
|
||||||
|
let data = get(&node.input[0])?;
|
||||||
|
let pads = get(&node.input[1])?;
|
||||||
|
if node.input.len() > 2 {
|
||||||
|
bail!(
|
||||||
|
"unsupported number of inputs {} for Pad node {:?}, expected 2",
|
||||||
|
node.input.len(),
|
||||||
|
node.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if pads.rank() != 1 {
|
||||||
|
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
|
||||||
|
}
|
||||||
|
if pads.dim(0).unwrap() != 2 * data.rank() {
|
||||||
|
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
|
||||||
|
}
|
||||||
|
|
||||||
|
let pads = pads.to_vec1::<i64>()?;
|
||||||
|
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
|
||||||
|
|
||||||
|
match mode {
|
||||||
|
"reflect" => {
|
||||||
|
let mut out = data.clone();
|
||||||
|
for (i, &dim) in data.dims().iter().enumerate().rev() {
|
||||||
|
if pads_pre[i] == 0 && pads_post[i] == 0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
|
||||||
|
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
|
||||||
|
}
|
||||||
|
let idx = if dim > 1 {
|
||||||
|
let cycle_len = dim * 2 - 1;
|
||||||
|
let skip = (pads_pre[i] as usize) % cycle_len;
|
||||||
|
let idx = zigzag(0, (dim - 1) as i64)
|
||||||
|
.skip(skip)
|
||||||
|
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
|
||||||
|
Tensor::from_iter(idx, out.device())?
|
||||||
|
} else {
|
||||||
|
Tensor::full(0i64, (dim,), out.device())?
|
||||||
|
};
|
||||||
|
|
||||||
|
out = out.index_select(&idx, i)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), out);
|
||||||
|
}
|
||||||
|
_ => bail!(
|
||||||
|
"unsupported 'mode' value {mode:?} for Pad node {:?}",
|
||||||
|
node.name
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
|
||||||
|
"Slice" => {
|
||||||
|
let data = get(&node.input[0])?;
|
||||||
|
let starts = get(&node.input[1])?;
|
||||||
|
let ends = get(&node.input[2])?;
|
||||||
|
let default_axes;
|
||||||
|
let default_steps;
|
||||||
|
let axes: &Tensor;
|
||||||
|
let steps: &Tensor;
|
||||||
|
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
|
||||||
|
// they are set to [1, ..., 1] of length len(starts)
|
||||||
|
match node.input.len() {
|
||||||
|
3 => {
|
||||||
|
let len = starts.dims()[0];
|
||||||
|
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
|
||||||
|
axes = default_axes.as_ref().unwrap();
|
||||||
|
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||||
|
steps = default_steps.as_ref().unwrap();
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
let len = starts.dims()[0];
|
||||||
|
axes = get(&node.input[3])?;
|
||||||
|
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||||
|
steps = default_steps.as_ref().unwrap();
|
||||||
|
}
|
||||||
|
5 => {
|
||||||
|
steps = get(&node.input[4])?;
|
||||||
|
axes = get(&node.input[3])?;
|
||||||
|
}
|
||||||
|
_ => bail!(
|
||||||
|
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
|
||||||
|
node.input.len(),
|
||||||
|
node
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = data.clone();
|
||||||
|
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
|
||||||
|
// All negative elements of axes are made non-negative by
|
||||||
|
// adding r to them, where r = rank(input).
|
||||||
|
let axis = if axis < 0 {
|
||||||
|
axis + data.rank() as i64
|
||||||
|
} else {
|
||||||
|
axis
|
||||||
|
} as usize;
|
||||||
|
|
||||||
|
let data_dim = data.dims()[axis] as i64;
|
||||||
|
let mut s = starts.get(i)?.to_scalar::<i64>()?;
|
||||||
|
let mut e = ends.get(i)?.to_scalar::<i64>()?;
|
||||||
|
// All negative values in starts[i] and ends[i] have
|
||||||
|
// dims[axes[i]] added to them, where dims are the
|
||||||
|
// dimensions of input.
|
||||||
|
if s < 0 {
|
||||||
|
s += data_dim;
|
||||||
|
}
|
||||||
|
if e < 0 {
|
||||||
|
e += data_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
let p = steps.get(i)?.to_scalar::<i64>()?;
|
||||||
|
// starts[i] is clamped into the range [0, dims[axes[i]]]
|
||||||
|
// for positive stepping and [0, dims[axes[i]]-1] for
|
||||||
|
// negative stepping.
|
||||||
|
// for positive stepping ends[axes[i]] is clamped to
|
||||||
|
// [0, dims[axes[i]]], while for negative stepping it is
|
||||||
|
// clamped to [-1, dims[axes[i]]-1].
|
||||||
|
if p >= 0 {
|
||||||
|
s = s.clamp(0, data_dim);
|
||||||
|
e = e.clamp(0, data_dim);
|
||||||
|
} else {
|
||||||
|
s = s.clamp(0, data_dim - 1);
|
||||||
|
e = e.clamp(-1, data_dim - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||||
|
out = out.index_select(&indexes, axis)?
|
||||||
|
}
|
||||||
|
values.insert(node.output[0].clone(), out);
|
||||||
|
}
|
||||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||||
"ReduceMean" => {
|
"ReduceMean" => {
|
||||||
@ -1099,6 +1274,30 @@ pub fn simple_eval(
|
|||||||
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
|
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
|
||||||
|
"Gemm" => {
|
||||||
|
let a = get(&node.input[0])?;
|
||||||
|
let b = get(&node.input[1])?;
|
||||||
|
let c = get(&node.input[2])?;
|
||||||
|
|
||||||
|
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
|
||||||
|
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
|
||||||
|
|
||||||
|
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
|
||||||
|
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
|
||||||
|
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
|
||||||
|
|
||||||
|
let a = if trans_a == 0 { a.clone() } else { a.t()? };
|
||||||
|
let b = if trans_b == 0 { b.clone() } else { b.t()? };
|
||||||
|
|
||||||
|
let output = a
|
||||||
|
.broadcast_mul(&alpha)?
|
||||||
|
.broadcast_matmul(&b)?
|
||||||
|
.broadcast_add(&c.broadcast_mul(&beta)?)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,10 +4,12 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use candle::test_utils::to_vec2_round;
|
||||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||||
use candle_onnx::onnx;
|
|
||||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||||
use candle_onnx::onnx::tensor_proto::DataType;
|
use candle_onnx::onnx::tensor_proto::DataType;
|
||||||
|
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
|
||||||
|
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
|
||||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
|
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
|
||||||
let manual_graph = create_model_proto_with_graph(None);
|
let manual_graph = create_model_proto_with_graph(None);
|
||||||
|
|
||||||
let inputs: HashMap<String, Tensor> = HashMap::new();
|
let inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
|
||||||
match candle_onnx::simple_eval(&manual_graph, inputs) {
|
match candle_onnx::simple_eval(&manual_graph, inputs) {
|
||||||
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
|
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
|
||||||
Ok(_) => panic!("Expected an error due to undefined graph"),
|
Ok(_) => panic!("Expected an error due to undefined graph"),
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
|
|||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
let first = z
|
let first = z.to_vec1::<f64>()?[0];
|
||||||
.to_vec1::<f64>()?
|
|
||||||
.to_vec()
|
|
||||||
.get(0)
|
|
||||||
.expect("Failed to get first element")
|
|
||||||
.clone();
|
|
||||||
assert_eq!(first, 4.0f64);
|
assert_eq!(first, 4.0f64);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
|
|||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
let first = z
|
let first = z.to_vec1::<f64>()?[0];
|
||||||
.to_vec1::<f64>()?
|
|
||||||
.to_vec()
|
|
||||||
.get(0)
|
|
||||||
.expect("Failed to get first element")
|
|
||||||
.clone();
|
|
||||||
assert_eq!(first, 0.0f64);
|
assert_eq!(first, 0.0f64);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
|
|||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
let first = z
|
let first = z.to_vec1::<f64>()?[0];
|
||||||
.to_vec1::<f64>()?
|
|
||||||
.to_vec()
|
|
||||||
.get(0)
|
|
||||||
.expect("Failed to get first element")
|
|
||||||
.clone();
|
|
||||||
assert_eq!(first, 4.0f64);
|
assert_eq!(first, 4.0f64);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
|
|||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
let first = z
|
let first = z.to_vec1::<f64>()?[0];
|
||||||
.to_vec1::<f64>()?
|
|
||||||
.to_vec()
|
|
||||||
.get(0)
|
|
||||||
.expect("Failed to get first element")
|
|
||||||
.clone();
|
|
||||||
|
|
||||||
assert_eq!(first, 1.0f64);
|
assert_eq!(first, 1.0f64);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
|
|||||||
|
|
||||||
assert_eq!(results[0][0], 0.36787944f32);
|
assert_eq!(results[0][0], 0.36787944f32);
|
||||||
assert_eq!(results[0][1], 1.0f32);
|
assert_eq!(results[0][1], 1.0f32);
|
||||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
|
assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
|
|||||||
),
|
),
|
||||||
_ => panic!("unsupported DType in test"),
|
_ => panic!("unsupported DType in test"),
|
||||||
};
|
};
|
||||||
let tensor = onnx::TensorProto {
|
let tensor = TensorProto {
|
||||||
data_type: data_type.into(),
|
data_type: data_type.into(),
|
||||||
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
|
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
|
||||||
raw_data: value,
|
raw_data: value,
|
||||||
@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
|
|||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
|
||||||
let results = z.to_vec2::<f32>()?;
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
results,
|
|
||||||
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> {
|
|||||||
quantization_annotation: vec![],
|
quantization_annotation: vec![],
|
||||||
}));
|
}));
|
||||||
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
|
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
|
||||||
|
|
||||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
inputs.insert(INPUT_X.to_string(), x);
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
assert_eq!(eval.len(), 1);
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
|
||||||
let results = z.to_vec2::<f32>()?;
|
|
||||||
|
|
||||||
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3150,3 +3110,300 @@ fn test_leakyrelu() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "If"
|
||||||
|
#[test]
|
||||||
|
fn test_if() -> Result<()> {
|
||||||
|
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||||
|
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
|
||||||
|
let output_type_proto = Some(TypeProto {
|
||||||
|
value: Some(type_proto::Value::TensorType(type_proto::Tensor {
|
||||||
|
elem_type: DataType::Float.into(),
|
||||||
|
shape: Some(TensorShapeProto {
|
||||||
|
dim: vec![Dimension {
|
||||||
|
denotation: "".to_string(),
|
||||||
|
value: Some(dimension::Value::DimValue(5)),
|
||||||
|
}],
|
||||||
|
}),
|
||||||
|
})),
|
||||||
|
denotation: "".to_string(),
|
||||||
|
});
|
||||||
|
let then_branch = GraphProto {
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: "then_out".to_string(),
|
||||||
|
r#type: output_type_proto.clone(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Constant".to_string(),
|
||||||
|
input: vec![],
|
||||||
|
output: vec!["then_out".to_string()],
|
||||||
|
attribute: vec![AttributeProto {
|
||||||
|
name: "value".to_string(),
|
||||||
|
r#type: AttributeType::Tensor.into(),
|
||||||
|
t: Some(TensorProto {
|
||||||
|
dims: vec![x.len() as i64],
|
||||||
|
float_data: x.clone(),
|
||||||
|
data_type: DataType::Float.into(),
|
||||||
|
..TensorProto::default()
|
||||||
|
}),
|
||||||
|
..AttributeProto::default()
|
||||||
|
}],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
..GraphProto::default()
|
||||||
|
};
|
||||||
|
let else_branch = GraphProto {
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: "else_out".to_string(),
|
||||||
|
r#type: output_type_proto.clone(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Constant".to_string(),
|
||||||
|
input: vec![],
|
||||||
|
output: vec!["else_out".to_string()],
|
||||||
|
attribute: vec![AttributeProto {
|
||||||
|
name: "value".to_string(),
|
||||||
|
r#type: AttributeType::Tensor.into(),
|
||||||
|
t: Some(TensorProto {
|
||||||
|
dims: vec![y.len() as i64],
|
||||||
|
float_data: y.clone(),
|
||||||
|
data_type: DataType::Float.into(),
|
||||||
|
..TensorProto::default()
|
||||||
|
}),
|
||||||
|
..AttributeProto::default()
|
||||||
|
}],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
..GraphProto::default()
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "If".to_string(),
|
||||||
|
attribute: vec![
|
||||||
|
AttributeProto {
|
||||||
|
name: "then_branch".to_string(),
|
||||||
|
r#type: AttributeType::Graph.into(),
|
||||||
|
g: Some(then_branch),
|
||||||
|
..AttributeProto::default()
|
||||||
|
},
|
||||||
|
AttributeProto {
|
||||||
|
name: "else_branch".to_string(),
|
||||||
|
r#type: AttributeType::Graph.into(),
|
||||||
|
g: Some(else_branch),
|
||||||
|
..AttributeProto::default()
|
||||||
|
},
|
||||||
|
],
|
||||||
|
input: vec!["cond".to_string()],
|
||||||
|
output: vec!["res".to_string()],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: "res".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: output_type_proto.clone(),
|
||||||
|
}],
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
|
||||||
|
for cond in [1u8, 0] {
|
||||||
|
let inputs =
|
||||||
|
HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
|
||||||
|
let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
let expected = if cond != 0 { &x } else { &y };
|
||||||
|
let Some(res) = outputs.get("res") else {
|
||||||
|
candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
|
||||||
|
};
|
||||||
|
assert_eq!(&res.to_vec1::<f32>()?, expected);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pad() -> Result<()> {
|
||||||
|
let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
|
||||||
|
let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
|
||||||
|
let mode = "reflect";
|
||||||
|
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
|
||||||
|
(3, 4),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: "data".to_string(),
|
||||||
|
..ValueInfoProto::default()
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: "pads".to_string(),
|
||||||
|
..ValueInfoProto::default()
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: "output".to_string(),
|
||||||
|
..ValueInfoProto::default()
|
||||||
|
}],
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Pad".to_string(),
|
||||||
|
input: vec!["data".to_string(), "pads".to_string()],
|
||||||
|
output: vec!["output".to_string()],
|
||||||
|
attribute: vec![AttributeProto {
|
||||||
|
name: "mode".to_string(),
|
||||||
|
r#type: AttributeType::String.into(),
|
||||||
|
s: mode.as_bytes().to_vec(),
|
||||||
|
..AttributeProto::default()
|
||||||
|
}],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
|
||||||
|
let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
|
||||||
|
let res = candle_onnx::simple_eval(&model, inputs)?;
|
||||||
|
let Some(actual) = res.get("output") else {
|
||||||
|
candle::bail!("outputs didn't contain expected key `output`: {res:?}");
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_slice() -> Result<()> {
|
||||||
|
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Slice".to_string(),
|
||||||
|
input: vec![
|
||||||
|
"data".to_string(),
|
||||||
|
"starts".to_string(),
|
||||||
|
"ends".to_string(),
|
||||||
|
"axes".to_string(),
|
||||||
|
"steps".to_string(),
|
||||||
|
],
|
||||||
|
output: vec!["result".to_string()],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
input: ["data", "starts", "ends", "axes", "steps"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
output: ["result"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
|
||||||
|
/*
|
||||||
|
data = [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
]
|
||||||
|
axes = [0, 1]
|
||||||
|
starts = [1, 0]
|
||||||
|
ends = [2, 3]
|
||||||
|
steps = [1, 2]
|
||||||
|
result = [
|
||||||
|
[5, 7],
|
||||||
|
]
|
||||||
|
*/
|
||||||
|
|
||||||
|
let outputs = candle_onnx::simple_eval(
|
||||||
|
&model,
|
||||||
|
HashMap::from_iter([
|
||||||
|
(
|
||||||
|
"data".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"starts".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ends".to_string(),
|
||||||
|
Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"axes".to_string(),
|
||||||
|
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"steps".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
)?;
|
||||||
|
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
|
||||||
|
assert_eq!(actual, vec![vec![5i64, 7]]);
|
||||||
|
|
||||||
|
/*
|
||||||
|
data = [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
]
|
||||||
|
starts = [0, 1]
|
||||||
|
ends = [-1, 1000]
|
||||||
|
result = [
|
||||||
|
[2, 3, 4],
|
||||||
|
]
|
||||||
|
*/
|
||||||
|
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Slice".to_string(),
|
||||||
|
input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
|
||||||
|
output: vec!["result".to_string()],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
input: ["data", "starts", "ends"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
output: ["result"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
let outputs = candle_onnx::simple_eval(
|
||||||
|
&model,
|
||||||
|
HashMap::from_iter([
|
||||||
|
(
|
||||||
|
"data".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"starts".to_string(),
|
||||||
|
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ends".to_string(),
|
||||||
|
Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
)?;
|
||||||
|
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
|
||||||
|
assert_eq!(actual, vec![vec![2i64, 3, 4]]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
@ -0,0 +1,553 @@
|
|||||||
|
use candle::D::Minus1;
|
||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
use candle_nn::ops::Identity;
|
||||||
|
use candle_nn::{
|
||||||
|
batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
|
||||||
|
BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::models::dinov2::DinoVisionTransformer;
|
||||||
|
|
||||||
|
pub struct DepthAnythingV2Config {
|
||||||
|
out_channel_sizes: [usize; 4],
|
||||||
|
in_channel_size: usize, // embed_dim in the Dino model
|
||||||
|
num_features: usize,
|
||||||
|
use_batch_norm: bool,
|
||||||
|
use_class_token: bool,
|
||||||
|
layer_ids_vits: Vec<usize>,
|
||||||
|
input_image_size: usize,
|
||||||
|
target_patch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DepthAnythingV2Config {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn new(
|
||||||
|
out_channel_sizes: [usize; 4],
|
||||||
|
in_channel_size: usize,
|
||||||
|
num_features: usize,
|
||||||
|
use_batch_norm: bool,
|
||||||
|
use_class_token: bool,
|
||||||
|
layer_ids_vits: Vec<usize>,
|
||||||
|
input_image_size: usize,
|
||||||
|
target_patch_size: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
out_channel_sizes,
|
||||||
|
in_channel_size,
|
||||||
|
num_features,
|
||||||
|
use_batch_norm,
|
||||||
|
use_class_token,
|
||||||
|
layer_ids_vits,
|
||||||
|
input_image_size,
|
||||||
|
target_patch_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vit_small() -> Self {
|
||||||
|
Self {
|
||||||
|
out_channel_sizes: [48, 96, 192, 384],
|
||||||
|
in_channel_size: 384,
|
||||||
|
num_features: 64,
|
||||||
|
use_batch_norm: false,
|
||||||
|
use_class_token: false,
|
||||||
|
layer_ids_vits: vec![2, 5, 8, 11],
|
||||||
|
input_image_size: 518,
|
||||||
|
target_patch_size: 518 / 14,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vit_base() -> Self {
|
||||||
|
Self {
|
||||||
|
out_channel_sizes: [96, 192, 384, 768],
|
||||||
|
in_channel_size: 768,
|
||||||
|
num_features: 128,
|
||||||
|
use_batch_norm: false,
|
||||||
|
use_class_token: false,
|
||||||
|
layer_ids_vits: vec![2, 5, 8, 11],
|
||||||
|
input_image_size: 518,
|
||||||
|
target_patch_size: 518 / 14,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vit_large() -> Self {
|
||||||
|
Self {
|
||||||
|
out_channel_sizes: [256, 512, 1024, 1024],
|
||||||
|
in_channel_size: 1024,
|
||||||
|
num_features: 256,
|
||||||
|
use_batch_norm: false,
|
||||||
|
use_class_token: false,
|
||||||
|
layer_ids_vits: vec![4, 11, 17, 23],
|
||||||
|
input_image_size: 518,
|
||||||
|
target_patch_size: 518 / 14,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn vit_giant() -> Self {
|
||||||
|
Self {
|
||||||
|
out_channel_sizes: [1536, 1536, 1536, 1536],
|
||||||
|
in_channel_size: 1536,
|
||||||
|
num_features: 384,
|
||||||
|
use_batch_norm: false,
|
||||||
|
use_class_token: false,
|
||||||
|
layer_ids_vits: vec![9, 19, 29, 39],
|
||||||
|
input_image_size: 518,
|
||||||
|
target_patch_size: 518 / 14,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ResidualConvUnit {
|
||||||
|
activation: Activation,
|
||||||
|
conv1: Conv2d,
|
||||||
|
conv2: Conv2d,
|
||||||
|
batch_norm1: Option<BatchNorm>,
|
||||||
|
batch_norm2: Option<BatchNorm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualConvUnit {
|
||||||
|
pub fn new(
|
||||||
|
conf: &DepthAnythingV2Config,
|
||||||
|
activation: Activation,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 3;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let conv1 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("conv1"),
|
||||||
|
)?;
|
||||||
|
let conv2 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("conv2"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
|
||||||
|
true => {
|
||||||
|
let batch_norm_cfg = BatchNormConfig {
|
||||||
|
eps: 1e-05,
|
||||||
|
remove_mean: false,
|
||||||
|
affine: true,
|
||||||
|
momentum: 0.1,
|
||||||
|
};
|
||||||
|
(
|
||||||
|
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
|
||||||
|
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
false => (None, None),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
activation,
|
||||||
|
conv1,
|
||||||
|
conv2,
|
||||||
|
batch_norm1,
|
||||||
|
batch_norm2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ResidualConvUnit {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let out = self.activation.forward(xs)?;
|
||||||
|
let out = self.conv1.forward(&out)?;
|
||||||
|
let out = if let Some(batch_norm1) = &self.batch_norm1 {
|
||||||
|
batch_norm1.forward_train(&out)?
|
||||||
|
} else {
|
||||||
|
out
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = self.activation.forward(&out)?;
|
||||||
|
let out = self.conv2.forward(&out)?;
|
||||||
|
let out = if let Some(batch_norm2) = &self.batch_norm2 {
|
||||||
|
batch_norm2.forward_train(&out)?
|
||||||
|
} else {
|
||||||
|
out
|
||||||
|
};
|
||||||
|
|
||||||
|
out + xs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FeatureFusionBlock {
|
||||||
|
res_conv_unit1: ResidualConvUnit,
|
||||||
|
res_conv_unit2: ResidualConvUnit,
|
||||||
|
output_conv: Conv2d,
|
||||||
|
target_patch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeatureFusionBlock {
|
||||||
|
pub fn new(
|
||||||
|
conf: &DepthAnythingV2Config,
|
||||||
|
target_patch_size: usize,
|
||||||
|
activation: Activation,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 1;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let output_conv = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("out_conv"),
|
||||||
|
)?;
|
||||||
|
let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
|
||||||
|
let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
res_conv_unit1,
|
||||||
|
res_conv_unit2,
|
||||||
|
output_conv,
|
||||||
|
target_patch_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for FeatureFusionBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let out = self.res_conv_unit2.forward(xs)?;
|
||||||
|
let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
|
||||||
|
|
||||||
|
self.output_conv.forward(&out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Scratch {
|
||||||
|
layer1_rn: Conv2d,
|
||||||
|
layer2_rn: Conv2d,
|
||||||
|
layer3_rn: Conv2d,
|
||||||
|
layer4_rn: Conv2d,
|
||||||
|
refine_net1: FeatureFusionBlock,
|
||||||
|
refine_net2: FeatureFusionBlock,
|
||||||
|
refine_net3: FeatureFusionBlock,
|
||||||
|
refine_net4: FeatureFusionBlock,
|
||||||
|
output_conv1: Conv2d,
|
||||||
|
output_conv2: Sequential,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scratch {
|
||||||
|
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 3;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let layer1_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer1_rn"),
|
||||||
|
)?;
|
||||||
|
let layer2_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer2_rn"),
|
||||||
|
)?;
|
||||||
|
let layer3_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[2],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer3_rn"),
|
||||||
|
)?;
|
||||||
|
let layer4_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer4_rn"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let refine_net1 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 8,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet1"),
|
||||||
|
)?;
|
||||||
|
let refine_net2 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 4,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet2"),
|
||||||
|
)?;
|
||||||
|
let refine_net3 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 2,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet3"),
|
||||||
|
)?;
|
||||||
|
let refine_net4 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet4"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let output_conv1 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features / 2,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv1"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let output_conv2 = seq();
|
||||||
|
const HEAD_FEATURES_2: usize = 32;
|
||||||
|
const OUT_CHANNELS_2: usize = 1;
|
||||||
|
const KERNEL_SIZE_2: usize = 1;
|
||||||
|
let output_conv2 = output_conv2.add(conv2d(
|
||||||
|
conf.num_features / 2,
|
||||||
|
HEAD_FEATURES_2,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv2").pp("0"),
|
||||||
|
)?);
|
||||||
|
let output_conv2 = output_conv2
|
||||||
|
.add(Activation::Relu)
|
||||||
|
.add(conv2d(
|
||||||
|
HEAD_FEATURES_2,
|
||||||
|
OUT_CHANNELS_2,
|
||||||
|
KERNEL_SIZE_2,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv2").pp("2"),
|
||||||
|
)?)
|
||||||
|
.add(Activation::Relu);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
layer1_rn,
|
||||||
|
layer2_rn,
|
||||||
|
layer3_rn,
|
||||||
|
layer4_rn,
|
||||||
|
refine_net1,
|
||||||
|
refine_net2,
|
||||||
|
refine_net3,
|
||||||
|
refine_net4,
|
||||||
|
output_conv1,
|
||||||
|
output_conv2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const NUM_CHANNELS: usize = 4;
|
||||||
|
|
||||||
|
pub struct DPTHead<'a> {
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
projections: Vec<Conv2d>,
|
||||||
|
resize_layers: Vec<Box<dyn Module>>,
|
||||||
|
readout_projections: Vec<Sequential>,
|
||||||
|
scratch: Scratch,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DPTHead<'a> {
|
||||||
|
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
|
||||||
|
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
|
||||||
|
projections.push(conv2d(
|
||||||
|
conf.in_channel_size,
|
||||||
|
*out_channel_size,
|
||||||
|
1,
|
||||||
|
Default::default(),
|
||||||
|
vb.pp("projects").pp(conv_index.to_string()),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resize_layers: Vec<Box<dyn Module>> = vec![
|
||||||
|
Box::new(conv_transpose2d(
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
4,
|
||||||
|
ConvTranspose2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 4,
|
||||||
|
dilation: 1,
|
||||||
|
output_padding: 0,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("0"),
|
||||||
|
)?),
|
||||||
|
Box::new(conv_transpose2d(
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
2,
|
||||||
|
ConvTranspose2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 2,
|
||||||
|
dilation: 1,
|
||||||
|
output_padding: 0,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("1"),
|
||||||
|
)?),
|
||||||
|
Box::new(Identity::new()),
|
||||||
|
Box::new(conv2d(
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
3,
|
||||||
|
Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 2,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("3"),
|
||||||
|
)?),
|
||||||
|
];
|
||||||
|
|
||||||
|
let readout_projections = if conf.use_class_token {
|
||||||
|
let rop = Vec::with_capacity(NUM_CHANNELS);
|
||||||
|
for rop_index in 0..NUM_CHANNELS {
|
||||||
|
seq()
|
||||||
|
.add(linear(
|
||||||
|
2 * conf.in_channel_size,
|
||||||
|
conf.in_channel_size,
|
||||||
|
vb.pp("readout_projects").pp(rop_index.to_string()),
|
||||||
|
)?)
|
||||||
|
.add(Activation::Gelu);
|
||||||
|
}
|
||||||
|
rop
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
conf,
|
||||||
|
projections,
|
||||||
|
resize_layers,
|
||||||
|
readout_projections,
|
||||||
|
scratch,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DPTHead<'_> {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
|
||||||
|
for i in 0..NUM_CHANNELS {
|
||||||
|
let x = if self.conf.use_class_token {
|
||||||
|
let x = xs.get(i)?.get(0)?;
|
||||||
|
let class_token = xs.get(i)?.get(1)?;
|
||||||
|
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
|
||||||
|
let to_cat = [x, readout];
|
||||||
|
let cat = Tensor::cat(&to_cat, Minus1)?;
|
||||||
|
self.readout_projections[i].forward(&cat)?
|
||||||
|
} else {
|
||||||
|
xs.get(i)?
|
||||||
|
};
|
||||||
|
let x_dims = x.dims();
|
||||||
|
|
||||||
|
let x = x.permute((0, 2, 1))?.reshape((
|
||||||
|
x_dims[0],
|
||||||
|
x_dims[x_dims.len() - 1],
|
||||||
|
self.conf.target_patch_size,
|
||||||
|
self.conf.target_patch_size,
|
||||||
|
))?;
|
||||||
|
let x = self.projections[i].forward(&x)?;
|
||||||
|
|
||||||
|
let x = self.resize_layers[i].forward(&x)?;
|
||||||
|
out.push(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
|
||||||
|
let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
|
||||||
|
let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
|
||||||
|
let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
|
||||||
|
|
||||||
|
let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
|
||||||
|
|
||||||
|
let res3_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net3
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_3_rn)?;
|
||||||
|
let res3_out = path4.add(&res3_out)?;
|
||||||
|
let path3 = self.scratch.refine_net3.forward(&res3_out)?;
|
||||||
|
|
||||||
|
let res2_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net2
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_2_rn)?;
|
||||||
|
let res2_out = path3.add(&res2_out)?;
|
||||||
|
let path2 = self.scratch.refine_net2.forward(&res2_out)?;
|
||||||
|
|
||||||
|
let res1_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net1
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_1_rn)?;
|
||||||
|
let res1_out = path2.add(&res1_out)?;
|
||||||
|
let path1 = self.scratch.refine_net1.forward(&res1_out)?;
|
||||||
|
|
||||||
|
let out = self.scratch.output_conv1.forward(&path1)?;
|
||||||
|
|
||||||
|
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
|
||||||
|
|
||||||
|
self.scratch.output_conv2.forward(&out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DepthAnythingV2<'a> {
|
||||||
|
pretrained: &'a DinoVisionTransformer,
|
||||||
|
depth_head: DPTHead<'a>,
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DepthAnythingV2<'a> {
|
||||||
|
pub fn new(
|
||||||
|
pretrained: &'a DinoVisionTransformer,
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
pretrained,
|
||||||
|
depth_head,
|
||||||
|
conf,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Module for DepthAnythingV2<'a> {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let features = self.pretrained.get_intermediate_layers(
|
||||||
|
xs,
|
||||||
|
&self.conf.layer_ids_vits,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
)?;
|
||||||
|
let depth = self.depth_head.forward(&features)?;
|
||||||
|
|
||||||
|
depth.relu()
|
||||||
|
}
|
||||||
|
}
|
@ -258,6 +258,84 @@ impl DinoVisionTransformer {
|
|||||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||||
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_intermediate_layers_not_chunked(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
blocks_to_take: &[usize],
|
||||||
|
) -> Result<Vec<Tensor>> {
|
||||||
|
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||||
|
let mut output = Vec::new();
|
||||||
|
for (i, blk) in self.blocks.iter().enumerate() {
|
||||||
|
xs = blk.forward(&xs)?;
|
||||||
|
if blocks_to_take.contains(&i) {
|
||||||
|
output.push(xs.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if output.len() != blocks_to_take.len() {
|
||||||
|
candle::bail!(
|
||||||
|
"only {} / {} blocks found",
|
||||||
|
output.len(),
|
||||||
|
blocks_to_take.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_intermediate_layers(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
blocks_to_take: &[usize],
|
||||||
|
reshape: bool,
|
||||||
|
return_class_token: bool,
|
||||||
|
norm: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
|
||||||
|
let outputs = if norm {
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| self.norm.forward(out))
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
let class_tokens = outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| out.i((.., 0)))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let outputs = outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| out.i((.., 1..)))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let outputs = if reshape {
|
||||||
|
let (b, _c, w, h) = xs.dims4()?;
|
||||||
|
let patch_size = self.patch_embed.patch_size.0;
|
||||||
|
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| {
|
||||||
|
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
|
||||||
|
.transpose(2, 3)?
|
||||||
|
.transpose(1, 2)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
|
||||||
|
let outputs = if return_class_token {
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.zip(class_tokens.iter())
|
||||||
|
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
|
||||||
|
Tensor::stack(&outputs[..], 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for DinoVisionTransformer {
|
impl Module for DinoVisionTransformer {
|
||||||
|
@ -6,6 +6,7 @@ pub mod chatglm;
|
|||||||
pub mod clip;
|
pub mod clip;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod convnext;
|
pub mod convnext;
|
||||||
|
pub mod depth_anything_v2;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
pub mod distilbert;
|
pub mod distilbert;
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
|
@ -360,8 +360,12 @@ pub struct ModelForCausalLM {
|
|||||||
|
|
||||||
impl ModelForCausalLM {
|
impl ModelForCausalLM {
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
let base_model = Model::new(cfg, vb.clone())?;
|
||||||
let base_model = Model::new(cfg, vb)?;
|
let lm_head = if vb.contains_tensor("lm_head") {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
} else {
|
||||||
|
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
base_model,
|
base_model,
|
||||||
lm_head,
|
lm_head,
|
||||||
|
@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> {
|
|||||||
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
|
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
|
||||||
let layers = convs
|
let layers = convs
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.map(|&(in_c, out_c, name)| {
|
||||||
.map(|(_, &(in_c, out_c, name))| {
|
|
||||||
candle_nn::conv2d(
|
candle_nn::conv2d(
|
||||||
in_c,
|
in_c,
|
||||||
out_c,
|
out_c,
|
||||||
|
Reference in New Issue
Block a user