Segment-anything fixes: avoid normalizing twice. (#767)

* Segment-anything fixes: avoid normalizing twice.

* More fixes for the image aspect ratio.
This commit is contained in:
Laurent Mazare
2023-09-07 21:45:16 +01:00
committed by GitHub
parent 7396b8ed1a
commit 79c27fc489
3 changed files with 33 additions and 3 deletions

View File

@ -108,7 +108,8 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; let image =
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
println!("loaded image {image:?}"); println!("loaded image {image:?}");
let model = match args.model { let model = match args.model {
@ -125,7 +126,7 @@ pub fn main() -> anyhow::Result<()> {
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let (mask, iou_predictions) = sam.forward(&image, false)?; let (mask, iou_predictions) = sam.forward(&image, false)?;
println!("mask: {mask:?}"); println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}"); println!("iou_predictions: {iou_predictions:?}");
Ok(()) Ok(())
} }

View File

@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder; use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256; const PROMPT_EMBED_DIM: usize = 256;
const IMAGE_SIZE: usize = 1024; pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16; const VIT_PATCH_SIZE: usize = 16;
#[derive(Debug)] #[derive(Debug)]
@ -90,6 +90,7 @@ impl Sam {
fn preprocess(&self, img: &Tensor) -> Result<Tensor> { fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
let (c, h, w) = img.dims3()?; let (c, h, w) = img.dims3()?;
let img = img let img = img
.to_dtype(DType::F32)?
.broadcast_sub(&self.pixel_mean)? .broadcast_sub(&self.pixel_mean)?
.broadcast_div(&self.pixel_std)?; .broadcast_div(&self.pixel_std)?;
if h > IMAGE_SIZE || w > IMAGE_SIZE { if h > IMAGE_SIZE || w > IMAGE_SIZE {

View File

@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result<Device> {
} }
} }
pub fn load_image<P: AsRef<std::path::Path>>(
p: P,
resize_longest: Option<usize>,
) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?;
let img = match resize_longest {
None => img,
Some(resize_longest) => {
let (height, width) = (img.height(), img.width());
let resize_longest = resize_longest as u32;
let (height, width) = if height < width {
let h = (resize_longest * height) / width;
(h, resize_longest)
} else {
let w = (resize_longest * width) / height;
(resize_longest, w)
};
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
}
};
let (height, width) = (img.height() as usize, img.width() as usize);
let img = img.to_rgb8();
let data = img.into_raw();
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
}
pub fn load_image_and_resize<P: AsRef<std::path::Path>>( pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
p: P, p: P,
width: usize, width: usize,