mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
dinov2 - read images from disk and compute the class probabilities (#503)
* Load the image from disk and convert it to a tensor. * Tweak the function name.
This commit is contained in:
@ -85,7 +85,7 @@ impl LayerScale {
|
|||||||
|
|
||||||
impl Module for LayerScale {
|
impl Module for LayerScale {
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
xs * &self.gamma
|
xs.broadcast_mul(&self.gamma)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,10 +306,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
// TODO: apply imagenet normalization.
|
||||||
|
let image = candle_examples::load_image(args.image)?;
|
||||||
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||||
let _model = vit_small(vb)?;
|
let model = vit_small(vb)?;
|
||||||
println!("model built");
|
println!("model built");
|
||||||
|
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||||
|
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?;
|
||||||
|
println!("{prs}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -332,7 +332,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||||
let image_filename =
|
let image_filename =
|
||||||
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
|
output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
|
||||||
crate::utils::save_image(&image, image_filename)?
|
candle_examples::save_image(&image, image_filename)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -346,7 +346,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||||
crate::utils::save_image(&image, image_filename)?
|
candle_examples::save_image(&image, image_filename)?
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,25 +12,6 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
|
|||||||
Tensor::from_vec(vs, steps, &Device::Cpu)
|
Tensor::from_vec(vs, steps, &Device::Cpu)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Saves an image to disk using the image crate, this expects an input with shape
|
|
||||||
/// (c, width, height).
|
|
||||||
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
|
||||||
let p = p.as_ref();
|
|
||||||
let (channel, width, height) = img.dims3()?;
|
|
||||||
if channel != 3 {
|
|
||||||
candle::bail!("save_image expects an input of shape (3, width, height)")
|
|
||||||
}
|
|
||||||
let img = img.transpose(0, 1)?.t()?.flatten_all()?;
|
|
||||||
let pixels = img.to_vec1::<u8>()?;
|
|
||||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
|
||||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
|
||||||
Some(image) => image,
|
|
||||||
None => candle::bail!("error saving image {p:?}"),
|
|
||||||
};
|
|
||||||
image.save(p).map_err(candle::Error::wrap)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap the conv2d op to provide some tracing.
|
// Wrap the conv2d op to provide some tracing.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Conv2d {
|
pub struct Conv2d {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{Device, Result};
|
use candle::{Device, Result, Tensor};
|
||||||
|
|
||||||
pub fn device(cpu: bool) -> Result<Device> {
|
pub fn device(cpu: bool) -> Result<Device> {
|
||||||
if cpu {
|
if cpu {
|
||||||
@ -12,6 +12,42 @@ pub fn device(cpu: bool) -> Result<Device> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||||
|
/// (3, 224, 224). imagenet normalization is applied.
|
||||||
|
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||||
|
let img = image::io::Reader::open(p)?
|
||||||
|
.decode()
|
||||||
|
.map_err(candle::Error::wrap)?
|
||||||
|
.resize_to_fill(224, 224, image::imageops::FilterType::Triangle);
|
||||||
|
let img = img.to_rgb8();
|
||||||
|
let data = img.into_raw();
|
||||||
|
let data = Tensor::from_vec(data, (3, 224, 224), &Device::Cpu)?;
|
||||||
|
let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?;
|
||||||
|
(data.to_dtype(candle::DType::F32)? / 255.)?
|
||||||
|
.broadcast_sub(&mean)?
|
||||||
|
.broadcast_div(&std)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Saves an image to disk using the image crate, this expects an input with shape
|
||||||
|
/// (c, width, height).
|
||||||
|
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||||
|
let p = p.as_ref();
|
||||||
|
let (channel, width, height) = img.dims3()?;
|
||||||
|
if channel != 3 {
|
||||||
|
candle::bail!("save_image expects an input of shape (3, width, height)")
|
||||||
|
}
|
||||||
|
let img = img.transpose(0, 1)?.t()?.flatten_all()?;
|
||||||
|
let pixels = img.to_vec1::<u8>()?;
|
||||||
|
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||||
|
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||||
|
Some(image) => image,
|
||||||
|
None => candle::bail!("error saving image {p:?}"),
|
||||||
|
};
|
||||||
|
image.save(p).map_err(candle::Error::wrap)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||||
|
Reference in New Issue
Block a user