Add the permute op (similar to pytorch). (#504)

* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
This commit is contained in:
Laurent Mazare
2023-08-18 16:30:53 +01:00
committed by GitHub
parent 4f1541526c
commit cb069d6063
7 changed files with 85 additions and 4 deletions

View File

@ -306,8 +306,7 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
// TODO: apply imagenet normalization.
let image = candle_examples::load_image(args.image)?;
let image = candle_examples::load_image224(args.image)?;
println!("loaded image {image:?}");
let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };

View File

@ -13,8 +13,8 @@ pub fn device(cpu: bool) -> Result<Device> {
}
/// Loads an image from disk using the image crate, this returns a tensor with shape
/// (3, 224, 224). imagenet normaliation is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
/// (3, 224, 224). imagenet normalization is applied.
pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?