mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Generate a mask image + the scaled input image. (#769)
* Also round-trip the original image.
* Make it possible to use a safetensors input.
This commit is contained in:
@ -108,8 +108,20 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image =
|
let image = if args.image.ends_with(".safetensors") {
|
||||||
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?;
|
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||||
|
match tensors.remove("image") {
|
||||||
|
Some(image) => image,
|
||||||
|
None => {
|
||||||
|
if tensors.len() != 1 {
|
||||||
|
anyhow::bail!("multiple tensors in '{}'", args.image)
|
||||||
|
}
|
||||||
|
tensors.into_values().next().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
||||||
|
};
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
@ -128,5 +140,16 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let (mask, iou_predictions) = sam.forward(&image, false)?;
|
let (mask, iou_predictions) = sam.forward(&image, false)?;
|
||||||
println!("mask:\n{mask}");
|
println!("mask:\n{mask}");
|
||||||
println!("iou_predictions: {iou_predictions:?}");
|
println!("iou_predictions: {iou_predictions:?}");
|
||||||
|
|
||||||
|
// Save the mask as an image.
|
||||||
|
let mask = mask.ge(&mask.zeros_like()?)?;
|
||||||
|
let mask = (mask * 255.)?.squeeze(0)?;
|
||||||
|
let (_one, h, w) = mask.dims3()?;
|
||||||
|
let mask = mask.expand((3, h, w))?;
|
||||||
|
candle_examples::save_image(&mask, "sam_mask.png")?;
|
||||||
|
|
||||||
|
let image = sam.preprocess(&image)?;
|
||||||
|
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
|
||||||
|
candle_examples::save_image(&image, "sam_input_scaled.png")?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,15 @@ impl Sam {
|
|||||||
Ok((low_res_mask, iou_predictions))
|
Ok((low_res_mask, iou_predictions))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
/// Rescales `img` by the model's pixel statistics and clamps the result.
///
/// The input is multiplied (with broadcasting) by `self.pixel_std`, then
/// `self.pixel_mean` is added (with broadcasting), and every element is
/// finally limited to the range `[0, 255]`.
///
/// NOTE(review): presumably this is the inverse of the normalization done in
/// `preprocess`; its full body is not visible here, so confirm.
///
/// Returns the clamped tensor; an error from any tensor operation is
/// propagated to the caller.
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||||
|
let img = img
|
||||||
|
.broadcast_mul(&self.pixel_std)?
|
||||||
|
.broadcast_add(&self.pixel_mean)?;
|
||||||
|
// Clamp element-wise: floor at 0 first, then cap at 255.
img.maximum(&img.zeros_like()?)?
|
||||||
|
.minimum(&(img.ones_like()? * 255.)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||||
let (c, h, w) = img.dims3()?;
|
let (c, h, w) = img.dims3()?;
|
||||||
let img = img
|
let img = img
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
|
Reference in New Issue
Block a user