Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice.
* More fixes for the image aspect ratio.
@@ -108,7 +108,8 @@ pub fn main() -> anyhow::Result<()> {
 
     let device = candle_examples::device(args.cpu)?;
 
-    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+    let image =
+        candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
     println!("loaded image {image:?}");
 
     let model = match args.model {
@@ -125,7 +126,7 @@ pub fn main() -> anyhow::Result<()> {
     let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
 
     let (mask, iou_predictions) = sam.forward(&image, false)?;
-    println!("mask: {mask:?}");
+    println!("mask:\n{mask}");
     println!("iou_predictions: {iou_predictions:?}");
     Ok(())
 }
@@ -6,7 +6,7 @@ use crate::model_mask_decoder::MaskDecoder;
 use crate::model_prompt_encoder::PromptEncoder;
 
 const PROMPT_EMBED_DIM: usize = 256;
-const IMAGE_SIZE: usize = 1024;
+pub const IMAGE_SIZE: usize = 1024;
 const VIT_PATCH_SIZE: usize = 16;
 
 #[derive(Debug)]
@@ -90,6 +90,7 @@ impl Sam {
     fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
         let (c, h, w) = img.dims3()?;
         let img = img
+            .to_dtype(DType::F32)?
            .broadcast_sub(&self.pixel_mean)?
            .broadcast_div(&self.pixel_std)?;
        if h > IMAGE_SIZE || w > IMAGE_SIZE {
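The hunk above is the "avoid normalizing twice" part of the fix: previously the example loaded the image with `imagenet::load_image224`, which already applies ImageNet-style mean/std normalization, and `Sam::preprocess` then normalized again with the model's own pixel mean/std. The image is now loaded as raw u8 data, and the added `.to_dtype(DType::F32)?` lets `preprocess` do the conversion and normalization exactly once. As a minimal standalone sketch of that single per-channel step (the mean/std constants are the ones stored on the SAM model; the pixel values are made up for illustration):

fn main() {
    // Per-channel mean / std as stored on the SAM model (0-255 pixel scale).
    let mean = [123.675f32, 116.28, 103.53];
    let std_dev = [58.395f32, 57.12, 57.375];
    // One made-up RGB pixel, converted from u8 to f32 exactly once.
    let pixel = [52u8, 180, 255];
    let normalized: Vec<f32> = pixel
        .iter()
        .zip(mean.iter().zip(std_dev.iter()))
        .map(|(&p, (&m, &s))| (p as f32 - m) / s)
        .collect();
    println!("normalized pixel: {normalized:?}");
}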
@@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result<Device> {
     }
 }
 
+pub fn load_image<P: AsRef<std::path::Path>>(
+    p: P,
+    resize_longest: Option<usize>,
+) -> Result<Tensor> {
+    let img = image::io::Reader::open(p)?
+        .decode()
+        .map_err(candle::Error::wrap)?;
+    let img = match resize_longest {
+        None => img,
+        Some(resize_longest) => {
+            let (height, width) = (img.height(), img.width());
+            let resize_longest = resize_longest as u32;
+            let (height, width) = if height < width {
+                let h = (resize_longest * height) / width;
+                (h, resize_longest)
+            } else {
+                let w = (resize_longest * width) / height;
+                (resize_longest, w)
+            };
+            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
+        }
+    };
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let img = img.to_rgb8();
+    let data = img.into_raw();
+    Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
+}
+
 pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
     p: P,
     width: usize,
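The `load_image` helper added above covers the aspect-ratio part of the fix: when `resize_longest` is given (the example passes `model_sam::IMAGE_SIZE`, i.e. 1024), the longest side is scaled to that value and the other side follows proportionally, instead of squashing the image to a fixed square as `load_image224` did. A small self-contained sketch of just that integer arithmetic, with made-up input dimensions:

// Mirrors the longest-side computation in the `load_image` helper above.
fn resize_longest_side(height: u32, width: u32, resize_longest: u32) -> (u32, u32) {
    if height < width {
        // Width is the longest side: pin it to `resize_longest`.
        let h = (resize_longest * height) / width;
        (h, resize_longest)
    } else {
        // Height is the longest side (or the image is square).
        let w = (resize_longest * width) / height;
        (resize_longest, w)
    }
}

fn main() {
    // E.g. an 800x1200 image with resize_longest = 1024:
    // (1024 * 800) / 1200 = 682, so it becomes 682x1024.
    let (h, w) = resize_longest_side(800, 1200, 1024);
    println!("resized to {h}x{w}");
}

Integer division rounds the shorter side down, so the result can be off by a pixel from the exact ratio, which is harmless here.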