make DepthAnythingV2 more reusable (#2675)

* make DepthAnythingV2 more reusable

* Fix clippy lints.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Edgar Riba
2024-12-21 12:06:03 +01:00
committed by GitHub
parent 67cab7d6b8
commit 5c2f893e5a
2 changed files with 27 additions and 23 deletions

View File

@ -6,10 +6,8 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use std::ffi::OsString;
use std::path::PathBuf;
use clap::Parser;
use std::{ffi::OsString, path::PathBuf, sync::Arc};
use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
@ -82,7 +80,7 @@ pub fn main() -> anyhow::Result<()> {
};
let config = DepthAnythingV2Config::vit_small();
let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;
let depth_anything = DepthAnythingV2::new(Arc::new(dinov2), config, vb)?;
let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

View File

@ -4,6 +4,8 @@
//! - ["Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data"](https://github.com/LiheYoung/Depth-Anything)
//!
use std::sync::Arc;
use candle::D::Minus1;
use candle::{Module, Result, Tensor};
use candle_nn::ops::Identity;
@ -365,16 +367,18 @@ impl Scratch {
const NUM_CHANNELS: usize = 4;
pub struct DPTHead<'a> {
conf: &'a DepthAnythingV2Config,
pub struct DPTHead {
projections: Vec<Conv2d>,
resize_layers: Vec<Box<dyn Module>>,
readout_projections: Vec<Sequential>,
scratch: Scratch,
use_class_token: bool,
input_image_size: usize,
target_patch_size: usize,
}
impl<'a> DPTHead<'a> {
pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
impl DPTHead {
pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result<Self> {
let mut projections: Vec<Conv2d> = Vec::with_capacity(conf.out_channel_sizes.len());
for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() {
projections.push(conv2d(
@ -445,20 +449,22 @@ impl<'a> DPTHead<'a> {
let scratch = Scratch::new(conf, vb.pp("scratch"))?;
Ok(Self {
conf,
projections,
resize_layers,
readout_projections,
scratch,
use_class_token: conf.use_class_token,
input_image_size: conf.input_image_size,
target_patch_size: conf.target_patch_size,
})
}
}
impl Module for DPTHead<'_> {
impl Module for DPTHead {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut out: Vec<Tensor> = Vec::with_capacity(NUM_CHANNELS);
for i in 0..NUM_CHANNELS {
let x = if self.conf.use_class_token {
let x = if self.use_class_token {
let x = xs.get(i)?.get(0)?;
let class_token = xs.get(i)?.get(1)?;
let readout = class_token.unsqueeze(1)?.expand(x.shape())?;
@ -473,8 +479,8 @@ impl Module for DPTHead<'_> {
let x = x.permute((0, 2, 1))?.reshape((
x_dims[0],
x_dims[x_dims.len() - 1],
self.conf.target_patch_size,
self.conf.target_patch_size,
self.target_patch_size,
self.target_patch_size,
))?;
let x = self.projections[i].forward(&x)?;
@ -515,25 +521,25 @@ impl Module for DPTHead<'_> {
let out = self.scratch.output_conv1.forward(&path1)?;
let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?;
let out = out.interpolate2d(self.input_image_size, self.input_image_size)?;
self.scratch.output_conv2.forward(&out)
}
}
pub struct DepthAnythingV2<'a> {
pretrained: &'a DinoVisionTransformer,
depth_head: DPTHead<'a>,
conf: &'a DepthAnythingV2Config,
pub struct DepthAnythingV2 {
pretrained: Arc<DinoVisionTransformer>,
depth_head: DPTHead,
conf: DepthAnythingV2Config,
}
impl<'a> DepthAnythingV2<'a> {
impl DepthAnythingV2 {
pub fn new(
pretrained: &'a DinoVisionTransformer,
conf: &'a DepthAnythingV2Config,
pretrained: Arc<DinoVisionTransformer>,
conf: DepthAnythingV2Config,
vb: VarBuilder,
) -> Result<Self> {
let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?;
let depth_head = DPTHead::new(&conf, vb.pp("depth_head"))?;
Ok(Self {
pretrained,
@ -543,7 +549,7 @@ impl<'a> DepthAnythingV2<'a> {
}
}
impl Module for DepthAnythingV2<'_> {
impl Module for DepthAnythingV2 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let features = self.pretrained.get_intermediate_layers(
xs,