mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Segment-anything fixes: avoid normalizing twice. (#767)
* Segment-anything fixes: avoid normalizing twice. * More fixes for the image aspect ratio.
This commit is contained in:
@ -16,6 +16,34 @@ pub fn device(cpu: bool) -> Result<Device> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
resize_longest: Option<usize>,
|
||||
) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?;
|
||||
let img = match resize_longest {
|
||||
None => img,
|
||||
Some(resize_longest) => {
|
||||
let (height, width) = (img.height(), img.width());
|
||||
let resize_longest = resize_longest as u32;
|
||||
let (height, width) = if height < width {
|
||||
let h = (resize_longest * height) / width;
|
||||
(h, resize_longest)
|
||||
} else {
|
||||
let w = (resize_longest * width) / height;
|
||||
(resize_longest, w)
|
||||
};
|
||||
img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
|
||||
}
|
||||
};
|
||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
width: usize,
|
||||
|
Reference in New Issue
Block a user