Fix CIFAR10 dataset types and dimension ordering (#2845)

This commit is contained in:
Bryan Lee
2025-03-30 04:53:25 -04:00
committed by GitHub
parent cb02b389d5
commit 59c26195db

View File

@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
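// image-rs reports image dimensions as a (width, height) tuple.
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
// NOTE(review): the raw buffer returned by `to_rgb8().as_raw()` is row-major with
// interleaved channels, i.e. laid out as (height, width, channels) — confirm that the
// permute applied to the assembled tensor accounts for this.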
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
.to_dtype(DType::F32)?
.permute((0, 3, 2, 1))?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))