Depth Anything v2 (#2279)

* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Jeroen Vlek
2024-06-24 19:12:52 +02:00
committed by GitHub
parent 6baa1d486b
commit 242e006bbb
8 changed files with 911 additions and 1 deletions

View File

@ -258,6 +258,84 @@ impl DinoVisionTransformer {
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
fn get_intermediate_layers_not_chunked(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
) -> Result<Vec<Tensor>> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
let mut output = Vec::new();
for (i, blk) in self.blocks.iter().enumerate() {
xs = blk.forward(&xs)?;
if blocks_to_take.contains(&i) {
output.push(xs.clone());
}
}
if output.len() != blocks_to_take.len() {
candle::bail!(
"only {} / {} blocks found",
output.len(),
blocks_to_take.len()
);
}
Ok(output)
}
pub fn get_intermediate_layers(
&self,
xs: &Tensor,
blocks_to_take: &[usize],
reshape: bool,
return_class_token: bool,
norm: bool,
) -> Result<Tensor> {
let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?;
let outputs = if norm {
outputs
.iter()
.map(|out| self.norm.forward(out))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let class_tokens = outputs
.iter()
.map(|out| out.i((.., 0)))
.collect::<Result<Vec<_>>>()?;
let outputs = outputs
.iter()
.map(|out| out.i((.., 1..)))
.collect::<Result<Vec<_>>>()?;
let outputs = if reshape {
let (b, _c, w, h) = xs.dims4()?;
let patch_size = self.patch_embed.patch_size.0;
let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size));
outputs
.iter()
.map(|out| {
out.reshape((b, w / patch_size, h / patch_size, num_channels))?
.transpose(2, 3)?
.transpose(1, 2)
})
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
let outputs = if return_class_token {
outputs
.iter()
.zip(class_tokens.iter())
.map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1))
.collect::<Result<Vec<_>>>()?
} else {
outputs
};
Tensor::stack(&outputs[..], 0)
}
}
impl Module for DinoVisionTransformer {