diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs index 9a255511..6bef4dae 100644 --- a/candle-examples/examples/dinov2/main.rs +++ b/candle-examples/examples/dinov2/main.rs @@ -85,7 +85,7 @@ impl LayerScale { impl Module for LayerScale { fn forward(&self, xs: &Tensor) -> Result<Tensor> { - xs * &self.gamma + xs.broadcast_mul(&self.gamma) } } @@ -306,10 +306,17 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; + // load_image resizes to 224x224 and applies the imagenet normalization. + let image = candle_examples::load_image(args.image)?; + println!("loaded image {image:?}"); + let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let _model = vit_small(vb)?; + let model = vit_small(vb)?; println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)?; + println!("{prs}"); Ok(()) } diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 6edd8ae6..10bbd309 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -332,7 +332,7 @@ fn run(args: Args) -> Result<()> { let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); - crate::utils::save_image(&image, image_filename)? + candle_examples::save_image(&image, image_filename)? } } @@ -346,7 +346,7 @@ fn run(args: Args) -> Result<()> { let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?; let image_filename = output_filename(&final_image, idx + 1, num_samples, None); - crate::utils::save_image(&image, image_filename)? + candle_examples::save_image(&image, image_filename)? 
} Ok(()) } diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 5602a9ad..c62f17af 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -12,25 +12,6 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { Tensor::from_vec(vs, steps, &Device::Cpu) } -/// Saves an image to disk using the image crate, this expects an input with shape -/// (c, width, height). -pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { - let p = p.as_ref(); - let (channel, width, height) = img.dims3()?; - if channel != 3 { - candle::bail!("save_image expects an input of shape (3, width, height)") - } - let img = img.transpose(0, 1)?.t()?.flatten_all()?; - let pixels = img.to_vec1::<u8>()?; - let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = - match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { - Some(image) => image, - None => candle::bail!("error saving image {p:?}"), - }; - image.save(p).map_err(candle::Error::wrap)?; - Ok(()) -} - // Wrap the conv2d op to provide some tracing. #[derive(Debug)] pub struct Conv2d { diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 2b6009b4..93da0240 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -1,4 +1,4 @@ -use candle::{Device, Result}; +use candle::{Device, Result, Tensor}; pub fn device(cpu: bool) -> Result<Device> { if cpu { @@ -12,6 +12,42 @@ pub fn device(cpu: bool) -> Result<Device> { } } +/// Loads an image from disk using the image crate, this returns a channel-first tensor with +/// shape (3, 224, 224). imagenet normalization is applied. +pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? 
+ .resize_to_fill(224, 224, image::imageops::FilterType::Triangle); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (224, 224, 3), &Device::Cpu)?.t()?.transpose(0, 1)?; + let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + +/// Saves an image to disk using the image crate, this expects an input with shape +/// (c, height, width). +pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { + let p = p.as_ref(); + let (channel, height, width) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, height, width)") + } + let img = img.transpose(0, 1)?.t()?.flatten_all()?; + let pixels = img.to_vec1::<u8>()?; + let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} + #[cfg(test)] mod tests { // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856