mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Depth Anything v2 (#2279)
* define structs * construct ResidualConvUnit * forward() for ResidualConvUnit * implement FeatureFusionBlock * implement Scratch * implement DPTHead * add identity module * implement forward for DTPHead * add get_intermediate_layers to DinoVisionTransformer * implement DepthAnythingV2 * some minor tweaks * fix compile errors * fix var builder prefixes * setup initial example * use fixed patch size of 37 (518 / 14) * debugged until output * print min and max values * add some dynamism to the output location * scale input image * extract prep function * extract output path function * normalize image with magic mean and std * add spectral coloring * squeeze in the right place * make enterpolation optional * use bail instead of panic * omit unnecessary Shape call * remove empty curly braces * use bail instead of assert * use vb and pp * remove closures * extract config object * Apply rustfmt. * Fix some clippy lints. * More lints. * Use the array methods. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
|
|||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
palette = { version = "0.7.6", optional = true }
|
||||||
|
enterpolation = { version = "0.2.1", optional = true}
|
||||||
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
rubato = { version = "0.15.0", optional = true }
|
rubato = { version = "0.15.0", optional = true }
|
||||||
@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
|
|||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
microphone = ["cpal"]
|
microphone = ["cpal"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "llama_multiprocess"
|
name = "llama_multiprocess"
|
||||||
@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "depth_anything_v2"
|
||||||
|
required-features = ["depth_anything_v2"]
|
||||||
|
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
13
candle-examples/examples/depth_anything_v2/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-depth-anything-v2
|
||||||
|
|
||||||
|
[Depth Anything V2](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2) is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which
|
||||||
|
builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.
|
||||||
|
|
||||||
|
This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it.
|
||||||
|
|
||||||
|
## Running an example with color map and CUDA
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
```
|
||||||
|
|
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
50
candle-examples/examples/depth_anything_v2/color_map.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use enterpolation::linear::ConstEquidistantLinear;
|
||||||
|
use enterpolation::Generator;
|
||||||
|
use palette::LinSrgb;
|
||||||
|
|
||||||
|
use candle::Tensor;
|
||||||
|
|
||||||
|
pub struct SpectralRColormap {
|
||||||
|
gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpectralRColormap {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
// Define a colormap similar to 'Spectral_r' by specifying key colors.
|
||||||
|
// got the colors from ChatGPT-4o
|
||||||
|
let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
|
||||||
|
LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
|
||||||
|
LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
|
||||||
|
LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
|
||||||
|
LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
|
||||||
|
LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
|
||||||
|
LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
|
||||||
|
LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
|
||||||
|
LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
|
||||||
|
LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
|
||||||
|
]);
|
||||||
|
Self { gradient }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_color(&self, value: f32) -> LinSrgb {
|
||||||
|
self.gradient.gen(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
println!("Gray: {:?}", gray.dims());
|
||||||
|
let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
|
||||||
|
let rgb_values: Vec<f32> = gray_values
|
||||||
|
.iter()
|
||||||
|
.map(|g| self.get_color(*g))
|
||||||
|
.flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let [.., height, width] = gray.dims() else {
|
||||||
|
candle::bail!("Not enough dims!")
|
||||||
|
};
|
||||||
|
|
||||||
|
let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;
|
||||||
|
|
||||||
|
color.permute((2, 0, 1))
|
||||||
|
}
|
||||||
|
}
|
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
187
candle-examples/examples/depth_anything_v2/main.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
//! Depth Anything V2
|
||||||
|
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use std::ffi::OsString;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::DType::{F32, U8};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor};
|
||||||
|
use candle_examples::{load_image, load_image_and_resize, save_image};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
|
||||||
|
use candle_transformers::models::dinov2;
|
||||||
|
|
||||||
|
use crate::color_map::SpectralRColormap;
|
||||||
|
|
||||||
|
mod color_map;
|
||||||
|
|
||||||
|
// Normalization statistics taken from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
|
||||||
|
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
|
||||||
|
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];
|
||||||
|
|
||||||
|
const DINO_IMG_SIZE: usize = 518;
|
||||||
|
|
||||||
|
// Command-line arguments of the Depth Anything V2 example.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `clap` derive struct would become part of the generated `--help` text.
#[derive(Parser)]
struct Args {
    // Local DINOv2 weights file; when absent, `main` fetches
    // `dinov2_vits14.safetensors` from the `lmz/candle-dino-v2` hub repo.
    #[arg(long)]
    dinov2_model: Option<PathBuf>,

    // Local Depth Anything V2 weights file; when absent, `main` fetches
    // `depth_anything_v2_vits.safetensors` from the
    // `jeroenvlek/depth-anything-v2-safetensors` hub repo.
    #[arg(long)]
    depth_anything_v2_model: Option<PathBuf>,

    // Input image to estimate depth for.
    #[arg(long)]
    image: PathBuf,

    // Directory for the output image; defaults to the input image's directory.
    #[arg(long)]
    output_dir: Option<PathBuf>,

    // Forwarded to `candle_examples::device` to select the compute device.
    #[arg(long)]
    cpu: bool,

    // When set, the depth map is rendered through the spectral colormap
    // instead of as a gray-scale image.
    #[arg(long)]
    color_map: bool,
}
|
||||||
|
|
||||||
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let dinov2_model_file = match args.dinov2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("lmz/candle-dino-v2".into());
|
||||||
|
api.get("dinov2_vits14.safetensors")?
|
||||||
|
}
|
||||||
|
Some(dinov2_model) => dinov2_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", dinov2_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
|
||||||
|
let dinov2 = dinov2::vit_small(vb)?;
|
||||||
|
println!("DinoV2 model built");
|
||||||
|
|
||||||
|
let depth_anything_model_file = match args.depth_anything_v2_model {
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
|
||||||
|
api.get("depth_anything_v2_vits.safetensors")?
|
||||||
|
}
|
||||||
|
Some(depth_anything_model) => depth_anything_model,
|
||||||
|
};
|
||||||
|
println!("Using file {:?}", depth_anything_model_file);
|
||||||
|
|
||||||
|
let vb = unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = DepthAnythingV2Config::vit_small();
|
||||||
|
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
|
||||||
|
|
||||||
|
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;
|
||||||
|
|
||||||
|
println!("Loaded image {image:?}");
|
||||||
|
|
||||||
|
let depth = depth_anything.forward(&image)?;
|
||||||
|
|
||||||
|
println!("Got predictions {:?}", depth.shape());
|
||||||
|
|
||||||
|
let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;
|
||||||
|
|
||||||
|
let output_path = full_output_path(&args.image, &args.output_dir);
|
||||||
|
println!("Saving image to {}", output_path.to_string_lossy());
|
||||||
|
save_image(&output_image, output_path)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds the path of the output image: the input file name prefixed with
/// `depth_`, placed in `output_dir` when given and otherwise in the directory
/// of the input image.
///
/// # Panics
/// Panics when `image_path` has no file name, or when no `output_dir` is
/// given and `image_path` has no parent.
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
    let mut depth_file_name = OsString::from("depth_");
    depth_file_name.push(image_path.file_name().unwrap());

    let target_dir = output_dir
        .clone()
        .unwrap_or_else(|| image_path.parent().unwrap().to_path_buf());

    target_dir.join(depth_file_name)
}
|
||||||
|
|
||||||
|
fn load_and_prep_image(
|
||||||
|
image_path: &PathBuf,
|
||||||
|
device: &Device,
|
||||||
|
) -> anyhow::Result<(usize, usize, Tensor)> {
|
||||||
|
let (_original_image, original_height, original_width) = load_image(&image_path, None)?;
|
||||||
|
|
||||||
|
let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
|
||||||
|
.unsqueeze(0)?
|
||||||
|
.to_dtype(F32)?
|
||||||
|
.to_device(&device)?;
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(&device)?
|
||||||
|
.broadcast_as(image.shape())?;
|
||||||
|
let image = (image / max_pixel_val)?;
|
||||||
|
let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
|
||||||
|
|
||||||
|
Ok((original_height, original_width, image))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
|
||||||
|
let mean_tensor =
|
||||||
|
Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
let std_tensor =
|
||||||
|
Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
|
||||||
|
image.sub(&mean_tensor)?.div(&std_tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_process_image(
|
||||||
|
image: &Tensor,
|
||||||
|
original_height: usize,
|
||||||
|
original_width: usize,
|
||||||
|
color_map: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let out = image.interpolate2d(original_height, original_width)?;
|
||||||
|
let out = scale_image(&out)?;
|
||||||
|
|
||||||
|
let out = if color_map {
|
||||||
|
let spectral_r = SpectralRColormap::new();
|
||||||
|
spectral_r.gray2color(&out)?
|
||||||
|
} else {
|
||||||
|
let rgb_slice = [&out, &out, &out];
|
||||||
|
Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let max_pixel_val = Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(out.device())?
|
||||||
|
.broadcast_as(out.shape())?;
|
||||||
|
let out = (out * max_pixel_val)?;
|
||||||
|
|
||||||
|
out.to_dtype(U8)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scale_image(depth: &Tensor) -> Result<Tensor> {
|
||||||
|
let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;
|
||||||
|
|
||||||
|
let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
|
||||||
|
|
||||||
|
let min_val_tensor = Tensor::try_from(*min_val)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
let depth = (depth - min_val_tensor)?;
|
||||||
|
|
||||||
|
let range = max_val - min_val;
|
||||||
|
let range_tensor = Tensor::try_from(range)?
|
||||||
|
.to_device(depth.device())?
|
||||||
|
.broadcast_as(depth.shape())?;
|
||||||
|
|
||||||
|
depth / range_tensor
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
|
use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||||
@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
|
|||||||
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
n => candle::bail!("replication-pad with a size of {n} is not supported"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Identity;
|
||||||
|
|
||||||
|
impl Identity {
|
||||||
|
pub fn new() -> Identity {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Identity {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Identity {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
Ok(xs.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
553
candle-transformers/src/models/depth_anything_v2.rs
Normal file
@ -0,0 +1,553 @@
|
|||||||
|
use candle::D::Minus1;
|
||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
use candle_nn::ops::Identity;
|
||||||
|
use candle_nn::{
|
||||||
|
batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm,
|
||||||
|
BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::models::dinov2::DinoVisionTransformer;
|
||||||
|
|
||||||
|
/// Hyper-parameters of the Depth Anything V2 head and of the DINOv2 layers it
/// reads from.
pub struct DepthAnythingV2Config {
    out_channel_sizes: [usize; 4],
    in_channel_size: usize, // embed_dim in the Dino model
    num_features: usize,
    use_batch_norm: bool,
    use_class_token: bool,
    layer_ids_vits: Vec<usize>,
    input_image_size: usize,
    target_patch_size: usize,
}

impl DepthAnythingV2Config {
    /// Side length of the square input image shared by all presets.
    const PRESET_INPUT_IMAGE_SIZE: usize = 518;
    /// Patch grid side length shared by all presets (518 / 14).
    const PRESET_TARGET_PATCH_SIZE: usize = 518 / 14;

    /// Creates a configuration from explicit values.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        out_channel_sizes: [usize; 4],
        in_channel_size: usize,
        num_features: usize,
        use_batch_norm: bool,
        use_class_token: bool,
        layer_ids_vits: Vec<usize>,
        input_image_size: usize,
        target_patch_size: usize,
    ) -> Self {
        Self {
            out_channel_sizes,
            in_channel_size,
            num_features,
            use_batch_norm,
            use_class_token,
            layer_ids_vits,
            input_image_size,
            target_patch_size,
        }
    }

    /// Shared part of the presets: no batch norm, no class token, and the
    /// common input image / patch grid sizes.
    fn preset(
        out_channel_sizes: [usize; 4],
        in_channel_size: usize,
        num_features: usize,
        layer_ids_vits: [usize; 4],
    ) -> Self {
        Self::new(
            out_channel_sizes,
            in_channel_size,
            num_features,
            false,
            false,
            layer_ids_vits.to_vec(),
            Self::PRESET_INPUT_IMAGE_SIZE,
            Self::PRESET_TARGET_PATCH_SIZE,
        )
    }

    /// Preset for the ViT-small backbone.
    pub fn vit_small() -> Self {
        Self::preset([48, 96, 192, 384], 384, 64, [2, 5, 8, 11])
    }

    /// Preset for the ViT-base backbone.
    pub fn vit_base() -> Self {
        Self::preset([96, 192, 384, 768], 768, 128, [2, 5, 8, 11])
    }

    /// Preset for the ViT-large backbone.
    pub fn vit_large() -> Self {
        Self::preset([256, 512, 1024, 1024], 1024, 256, [4, 11, 17, 23])
    }

    /// Preset for the ViT-giant backbone.
    pub fn vit_giant() -> Self {
        Self::preset([1536, 1536, 1536, 1536], 1536, 384, [9, 19, 29, 39])
    }
}
|
||||||
|
|
||||||
|
pub struct ResidualConvUnit {
|
||||||
|
activation: Activation,
|
||||||
|
conv1: Conv2d,
|
||||||
|
conv2: Conv2d,
|
||||||
|
batch_norm1: Option<BatchNorm>,
|
||||||
|
batch_norm2: Option<BatchNorm>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualConvUnit {
|
||||||
|
pub fn new(
|
||||||
|
conf: &DepthAnythingV2Config,
|
||||||
|
activation: Activation,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 3;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let conv1 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("conv1"),
|
||||||
|
)?;
|
||||||
|
let conv2 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("conv2"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let (batch_norm1, batch_norm2) = match conf.use_batch_norm {
|
||||||
|
true => {
|
||||||
|
let batch_norm_cfg = BatchNormConfig {
|
||||||
|
eps: 1e-05,
|
||||||
|
remove_mean: false,
|
||||||
|
affine: true,
|
||||||
|
momentum: 0.1,
|
||||||
|
};
|
||||||
|
(
|
||||||
|
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?),
|
||||||
|
Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
false => (None, None),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
activation,
|
||||||
|
conv1,
|
||||||
|
conv2,
|
||||||
|
batch_norm1,
|
||||||
|
batch_norm2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ResidualConvUnit {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let out = self.activation.forward(xs)?;
|
||||||
|
let out = self.conv1.forward(&out)?;
|
||||||
|
let out = if let Some(batch_norm1) = &self.batch_norm1 {
|
||||||
|
batch_norm1.forward_train(&out)?
|
||||||
|
} else {
|
||||||
|
out
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = self.activation.forward(&out)?;
|
||||||
|
let out = self.conv2.forward(&out)?;
|
||||||
|
let out = if let Some(batch_norm2) = &self.batch_norm2 {
|
||||||
|
batch_norm2.forward_train(&out)?
|
||||||
|
} else {
|
||||||
|
out
|
||||||
|
};
|
||||||
|
|
||||||
|
out + xs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FeatureFusionBlock {
|
||||||
|
res_conv_unit1: ResidualConvUnit,
|
||||||
|
res_conv_unit2: ResidualConvUnit,
|
||||||
|
output_conv: Conv2d,
|
||||||
|
target_patch_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeatureFusionBlock {
|
||||||
|
pub fn new(
|
||||||
|
conf: &DepthAnythingV2Config,
|
||||||
|
target_patch_size: usize,
|
||||||
|
activation: Activation,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 1;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let output_conv = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("out_conv"),
|
||||||
|
)?;
|
||||||
|
let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?;
|
||||||
|
let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
res_conv_unit1,
|
||||||
|
res_conv_unit2,
|
||||||
|
output_conv,
|
||||||
|
target_patch_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for FeatureFusionBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let out = self.res_conv_unit2.forward(xs)?;
|
||||||
|
let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?;
|
||||||
|
|
||||||
|
self.output_conv.forward(&out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Scratch {
|
||||||
|
layer1_rn: Conv2d,
|
||||||
|
layer2_rn: Conv2d,
|
||||||
|
layer3_rn: Conv2d,
|
||||||
|
layer4_rn: Conv2d,
|
||||||
|
refine_net1: FeatureFusionBlock,
|
||||||
|
refine_net2: FeatureFusionBlock,
|
||||||
|
refine_net3: FeatureFusionBlock,
|
||||||
|
refine_net4: FeatureFusionBlock,
|
||||||
|
output_conv1: Conv2d,
|
||||||
|
output_conv2: Sequential,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scratch {
|
||||||
|
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
const KERNEL_SIZE: usize = 3;
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let layer1_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer1_rn"),
|
||||||
|
)?;
|
||||||
|
let layer2_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer2_rn"),
|
||||||
|
)?;
|
||||||
|
let layer3_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[2],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer3_rn"),
|
||||||
|
)?;
|
||||||
|
let layer4_rn = conv2d_no_bias(
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
conf.num_features,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("layer4_rn"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let refine_net1 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 8,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet1"),
|
||||||
|
)?;
|
||||||
|
let refine_net2 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 4,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet2"),
|
||||||
|
)?;
|
||||||
|
let refine_net3 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size * 2,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet3"),
|
||||||
|
)?;
|
||||||
|
let refine_net4 = FeatureFusionBlock::new(
|
||||||
|
conf,
|
||||||
|
conf.target_patch_size,
|
||||||
|
Activation::Relu,
|
||||||
|
vb.pp("refinenet4"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let conv_cfg = Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
};
|
||||||
|
let output_conv1 = conv2d(
|
||||||
|
conf.num_features,
|
||||||
|
conf.num_features / 2,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv1"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let output_conv2 = seq();
|
||||||
|
const HEAD_FEATURES_2: usize = 32;
|
||||||
|
const OUT_CHANNELS_2: usize = 1;
|
||||||
|
const KERNEL_SIZE_2: usize = 1;
|
||||||
|
let output_conv2 = output_conv2.add(conv2d(
|
||||||
|
conf.num_features / 2,
|
||||||
|
HEAD_FEATURES_2,
|
||||||
|
KERNEL_SIZE,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv2").pp("0"),
|
||||||
|
)?);
|
||||||
|
let output_conv2 = output_conv2
|
||||||
|
.add(Activation::Relu)
|
||||||
|
.add(conv2d(
|
||||||
|
HEAD_FEATURES_2,
|
||||||
|
OUT_CHANNELS_2,
|
||||||
|
KERNEL_SIZE_2,
|
||||||
|
conv_cfg,
|
||||||
|
vb.pp("output_conv2").pp("2"),
|
||||||
|
)?)
|
||||||
|
.add(Activation::Relu);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
layer1_rn,
|
||||||
|
layer2_rn,
|
||||||
|
layer3_rn,
|
||||||
|
layer4_rn,
|
||||||
|
refine_net1,
|
||||||
|
refine_net2,
|
||||||
|
refine_net3,
|
||||||
|
refine_net4,
|
||||||
|
output_conv1,
|
||||||
|
output_conv2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const NUM_CHANNELS: usize = 4;
|
||||||
|
|
||||||
|
pub struct DPTHead<'a> {
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
projections: Vec<Conv2d>,
|
||||||
|
resize_layers: Vec<Box<dyn Module>>,
|
||||||
|
readout_projections: Vec<Sequential>,
|
||||||
|
scratch: Scratch,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DPTHead<'a> {
|
||||||
|
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
|
||||||
|
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
|
||||||
|
projections.push(conv2d(
|
||||||
|
conf.in_channel_size,
|
||||||
|
*out_channel_size,
|
||||||
|
1,
|
||||||
|
Default::default(),
|
||||||
|
vb.pp("projects").pp(conv_index.to_string()),
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resize_layers: Vec<Box<dyn Module>> = vec![
|
||||||
|
Box::new(conv_transpose2d(
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
conf.out_channel_sizes[0],
|
||||||
|
4,
|
||||||
|
ConvTranspose2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 4,
|
||||||
|
dilation: 1,
|
||||||
|
output_padding: 0,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("0"),
|
||||||
|
)?),
|
||||||
|
Box::new(conv_transpose2d(
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
conf.out_channel_sizes[1],
|
||||||
|
2,
|
||||||
|
ConvTranspose2dConfig {
|
||||||
|
padding: 0,
|
||||||
|
stride: 2,
|
||||||
|
dilation: 1,
|
||||||
|
output_padding: 0,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("1"),
|
||||||
|
)?),
|
||||||
|
Box::new(Identity::new()),
|
||||||
|
Box::new(conv2d(
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
conf.out_channel_sizes[3],
|
||||||
|
3,
|
||||||
|
Conv2dConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 2,
|
||||||
|
dilation: 1,
|
||||||
|
groups: 1,
|
||||||
|
},
|
||||||
|
vb.pp("resize_layers").pp("3"),
|
||||||
|
)?),
|
||||||
|
];
|
||||||
|
|
||||||
|
let readout_projections = if conf.use_class_token {
|
||||||
|
let rop = Vec::with_capacity(NUM_CHANNELS);
|
||||||
|
for rop_index in 0..NUM_CHANNELS {
|
||||||
|
seq()
|
||||||
|
.add(linear(
|
||||||
|
2 * conf.in_channel_size,
|
||||||
|
conf.in_channel_size,
|
||||||
|
vb.pp("readout_projects").pp(rop_index.to_string()),
|
||||||
|
)?)
|
||||||
|
.add(Activation::Gelu);
|
||||||
|
}
|
||||||
|
rop
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
conf,
|
||||||
|
projections,
|
||||||
|
resize_layers,
|
||||||
|
readout_projections,
|
||||||
|
scratch,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DPTHead<'_> {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
|
||||||
|
for i in 0..NUM_CHANNELS {
|
||||||
|
let x = if self.conf.use_class_token {
|
||||||
|
let x = xs.get(i)?.get(0)?;
|
||||||
|
let class_token = xs.get(i)?.get(1)?;
|
||||||
|
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
|
||||||
|
let to_cat = [x, readout];
|
||||||
|
let cat = Tensor::cat(&to_cat, Minus1)?;
|
||||||
|
self.readout_projections[i].forward(&cat)?
|
||||||
|
} else {
|
||||||
|
xs.get(i)?
|
||||||
|
};
|
||||||
|
let x_dims = x.dims();
|
||||||
|
|
||||||
|
let x = x.permute((0, 2, 1))?.reshape((
|
||||||
|
x_dims[0],
|
||||||
|
x_dims[x_dims.len() - 1],
|
||||||
|
self.conf.target_patch_size,
|
||||||
|
self.conf.target_patch_size,
|
||||||
|
))?;
|
||||||
|
let x = self.projections[i].forward(&x)?;
|
||||||
|
|
||||||
|
let x = self.resize_layers[i].forward(&x)?;
|
||||||
|
out.push(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?;
|
||||||
|
let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?;
|
||||||
|
let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?;
|
||||||
|
let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?;
|
||||||
|
|
||||||
|
let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?;
|
||||||
|
|
||||||
|
let res3_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net3
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_3_rn)?;
|
||||||
|
let res3_out = path4.add(&res3_out)?;
|
||||||
|
let path3 = self.scratch.refine_net3.forward(&res3_out)?;
|
||||||
|
|
||||||
|
let res2_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net2
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_2_rn)?;
|
||||||
|
let res2_out = path3.add(&res2_out)?;
|
||||||
|
let path2 = self.scratch.refine_net2.forward(&res2_out)?;
|
||||||
|
|
||||||
|
let res1_out = self
|
||||||
|
.scratch
|
||||||
|
.refine_net1
|
||||||
|
.res_conv_unit1
|
||||||
|
.forward(&layer_1_rn)?;
|
||||||
|
let res1_out = path2.add(&res1_out)?;
|
||||||
|
let path1 = self.scratch.refine_net1.forward(&res1_out)?;
|
||||||
|
|
||||||
|
let out = self.scratch.output_conv1.forward(&path1)?;
|
||||||
|
|
||||||
|
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
|
||||||
|
|
||||||
|
self.scratch.output_conv2.forward(&out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DepthAnythingV2<'a> {
|
||||||
|
pretrained: &'a DinoVisionTransformer,
|
||||||
|
depth_head: DPTHead<'a>,
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> DepthAnythingV2<'a> {
|
||||||
|
pub fn new(
|
||||||
|
pretrained: &'a DinoVisionTransformer,
|
||||||
|
conf: &'a DepthAnythingV2Config,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
pretrained,
|
||||||
|
depth_head,
|
||||||
|
conf,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Module for DepthAnythingV2<'a> {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let features = self.pretrained.get_intermediate_layers(
|
||||||
|
xs,
|
||||||
|
&self.conf.layer_ids_vits,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
)?;
|
||||||
|
let depth = self.depth_head.forward(&features)?;
|
||||||
|
|
||||||
|
depth.relu()
|
||||||
|
}
|
||||||
|
}
|
@ -258,6 +258,84 @@ impl DinoVisionTransformer {
|
|||||||
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
|
||||||
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_intermediate_layers_not_chunked(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
blocks_to_take: &[usize],
|
||||||
|
) -> Result<Vec<Tensor>> {
|
||||||
|
let mut xs = self.prepare_tokens_with_mask(xs)?;
|
||||||
|
let mut output = Vec::new();
|
||||||
|
for (i, blk) in self.blocks.iter().enumerate() {
|
||||||
|
xs = blk.forward(&xs)?;
|
||||||
|
if blocks_to_take.contains(&i) {
|
||||||
|
output.push(xs.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if output.len() != blocks_to_take.len() {
|
||||||
|
candle::bail!(
|
||||||
|
"only {} / {} blocks found",
|
||||||
|
output.len(),
|
||||||
|
blocks_to_take.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_intermediate_layers(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
blocks_to_take: &[usize],
|
||||||
|
reshape: bool,
|
||||||
|
return_class_token: bool,
|
||||||
|
norm: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
|
||||||
|
let outputs = if norm {
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| self.norm.forward(out))
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
let class_tokens = outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| out.i((.., 0)))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let outputs = outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| out.i((.., 1..)))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let outputs = if reshape {
|
||||||
|
let (b, _c, w, h) = xs.dims4()?;
|
||||||
|
let patch_size = self.patch_embed.patch_size.0;
|
||||||
|
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.map(|out| {
|
||||||
|
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
|
||||||
|
.transpose(2, 3)?
|
||||||
|
.transpose(1, 2)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
|
||||||
|
let outputs = if return_class_token {
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.zip(class_tokens.iter())
|
||||||
|
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
} else {
|
||||||
|
outputs
|
||||||
|
};
|
||||||
|
|
||||||
|
Tensor::stack(&outputs[..], 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for DinoVisionTransformer {
|
impl Module for DinoVisionTransformer {
|
||||||
|
@ -6,6 +6,7 @@ pub mod chatglm;
|
|||||||
pub mod clip;
|
pub mod clip;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod convnext;
|
pub mod convnext;
|
||||||
|
pub mod depth_anything_v2;
|
||||||
pub mod dinov2;
|
pub mod dinov2;
|
||||||
pub mod distilbert;
|
pub mod distilbert;
|
||||||
pub mod efficientnet;
|
pub mod efficientnet;
|
||||||
|
Reference in New Issue
Block a user