mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch). * Add the backprop for dimension permutation.
This commit is contained in:
@ -306,8 +306,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
// TODO: apply imagenet normalization.
|
||||
let image = candle_examples::load_image(args.image)?;
|
||||
let image = candle_examples::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
||||
|
@ -13,8 +13,8 @@ pub fn device(cpu: bool) -> Result<Device> {
|
||||
}
|
||||
|
||||
/// Loads an image from disk using the image crate, this returns a tensor with shape
|
||||
/// (3, 224, 224). imagenet normaliation is applied.
|
||||
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
/// (3, 224, 224). imagenet normalization is applied.
|
||||
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
|
Reference in New Issue
Block a user