mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Fix CIFAR10 dataset types and dimension ordering (#2845)
This commit is contained in:
@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
if let parquet::record::Field::Group(subrow) = field {
|
if let parquet::record::Field::Group(subrow) = field {
|
||||||
for (_name, field) in subrow.get_column_iter() {
|
for (_name, field) in subrow.get_column_iter() {
|
||||||
if let parquet::record::Field::Bytes(value) = field {
|
if let parquet::record::Field::Bytes(value) = field {
|
||||||
|
// image-rs crate convention is to load in (width, height, channels) order
|
||||||
|
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
|
||||||
let image = image::load_from_memory(value.data()).unwrap();
|
let image = image::load_from_memory(value.data()).unwrap();
|
||||||
buffer_images.extend(image.to_rgb8().as_raw());
|
buffer_images.extend(image.to_rgb8().as_raw());
|
||||||
}
|
}
|
||||||
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
|
||||||
.to_dtype(DType::U8)?
|
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.permute((0, 3, 2, 1))?
|
||||||
/ 255.)?;
|
/ 255.)?;
|
||||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||||
Ok((images, labels))
|
Ok((images, labels))
|
||||||
|
Reference in New Issue
Block a user