From 59c26195db7e6ccb9ec86d7922781bd48bccba79 Mon Sep 17 00:00:00 2001 From: Bryan Lee Date: Sun, 30 Mar 2025 04:53:25 -0400 Subject: [PATCH] Fix CIFAR10 dataset types and dimension ordering (#2845) --- candle-datasets/src/vision/cifar.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/candle-datasets/src/vision/cifar.rs b/candle-datasets/src/vision/cifar.rs index 4b403a2e..7c66aa11 100644 --- a/candle-datasets/src/vision/cifar.rs +++ b/candle-datasets/src/vision/cifar.rs @@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader) -> Result<(Tensor, if let parquet::record::Field::Group(subrow) = field { for (_name, field) in subrow.get_column_iter() { if let parquet::record::Field::Bytes(value) = field { + // image-rs crate convention is to load in (width, height, channels) order + // See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions let image = image::load_from_memory(value.data()).unwrap(); buffer_images.extend(image.to_rgb8().as_raw()); } @@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader) -> Result<(Tensor, } } } - let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)? - .to_dtype(DType::U8)? + // Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width) + let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)? + .to_dtype(DType::F32)? + .permute((0, 3, 2, 1))? / 255.)?; let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?; Ok((images, labels))