diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md
index ddbbc7af..d68a917e 100644
--- a/candle-book/src/training/README.md
+++ b/candle-book/src/training/README.md
@@ -6,19 +6,19 @@ start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
-This requires `candle-datasets` with the `hub` feature.
+This requires the [`hf-hub`](https://github.com/huggingface/hf-hub) crate.
```bash
-cargo add candle-datasets --features hub
cargo add hf-hub
```
+This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
-`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`].
+Our file handles, `test_parquet` and `train_parquet`, are now [`parquet::file::serialized_reader::SerializedFileReader`]s.
We can inspect the content of the files with:
@@ -37,5 +37,3 @@ Column id 0, name image, value {bytes: [137, ....]
So each row contains 2 columns (image, label) with image being saved as bytes.
Let's put them into a useful struct.
-
-
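+
+Before building that struct, note that the image bytes are PNG-encoded (the `137` above is
+the first byte of the PNG magic number). They can be decoded with the
+[`image`](https://crates.io/crates/image) crate. A minimal sketch, assuming `value` holds
+the `Bytes` field of one `image` cell (the full loader is shown in the next chapter):
+
+```rust,ignore
+// `value.data()` yields the PNG-encoded bytes of a single 28x28 digit.
+let img = image::load_from_memory(value.data())?;
+let pixels = img.to_luma8(); // grayscale buffer of 28 * 28 = 784 bytes
+```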
diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md
index 642960a4..8ec00c1c 100644
--- a/candle-book/src/training/mnist.md
+++ b/candle-book/src/training/mnist.md
@@ -1 +1,10 @@
# MNIST
+
+Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
+```
+
+Parsing the files and loading the data into single tensors requires the entire dataset to fit in memory.
+It is quite rudimentary, but simple enough for a small dataset like MNIST.
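+
+Once loaded, mini-batches can be sliced directly out of these tensors. A minimal sketch,
+assuming `m` is the `Dataset` returned by the loader and a fixed batch size of 64:
+
+```rust,ignore
+let batch_size = 64;
+// Rows [0, 64) of the training set: a (64, 784) image batch and its (64,) labels.
+let batch_images = m.train_images.narrow(0, 0, batch_size)?;
+let batch_labels = m.train_labels.narrow(0, 0, batch_size)?;
+```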
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index 7c4c0aa5..4722c76c 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -19,7 +19,5 @@ memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
-parquet = { version = "45.0.0", optional = true }
-
-[features]
-hub = ["dep:parquet"]
+parquet = { version = "45.0.0" }
+image = { workspace = true }
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 2267f9a0..c908412c 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -3,6 +3,8 @@
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
@@ -63,3 +65,54 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
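+    // Each row is an (image, label) pair: the image is a group whose `bytes`
+    // field holds a PNG-encoded 28x28 digit, the label a 64-bit integer.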
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_luma8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
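+    // Assemble a (samples, 784) f32 image tensor scaled to [0, 1] and a (samples,) u8 label tensor.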
+    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
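+/// Loads the MNIST dataset from the Hugging Face hub via its parquet conversion.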
+pub fn load() -> Result<crate::vision::Dataset> {
+    let api = Api::new().unwrap();
+    let dataset_id = "mnist".to_string();
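+    // The `refs/convert/parquet` revision of a hub dataset holds its parquet conversion.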
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
+    let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
+    let test_parquet =
+        SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
+    let train_parquet =
+        SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 39d2bc72..cc1f1e4c 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -41,6 +41,8 @@ tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
+parquet = "45.0.0"
+image.workspace = true
[build-dependencies]
anyhow = { workspace = true }
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index bcf8677d..2cdad9fa 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -143,7 +143,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_datasets::vision::mnist::load_dir("data")?;
+ let m = candle_datasets::vision::mnist::load()?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 3fdd4cc9..0b716e4f 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
#[cfg(test)]
mod tests {
use anyhow::Result;
+ use candle::{DType, Device, Tensor};
+ use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
@@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
-use candle_datasets::hub::from_hub;
-use hf_hub::api::sync::Api;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+let dataset_id = "mnist".to_string();
let api = Api::new()?;
-let files = from_hub(&api, "mnist".to_string())?;
+let repo = Repo::with_revision(
+    dataset_id,
+    RepoType::Dataset,
+    "refs/convert/parquet".to_string(),
+);
+let repo = api.repo(repo);
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
+// Ignore unused
+let _train = train_parquet;
// ANCHOR: book_training_2
-let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
-for row in rows {
-    for (idx, (name, field)) in row.get_column_iter().enumerate() {
+for row in test_parquet {
+    for (idx, (name, field)) in row?.get_column_iter().enumerate() {
        println!("Column id {idx}, name {name}, value {field}");
    }
}
// ANCHOR_END: book_training_2
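+// Iterating consumed the readers above, so fetch and open them again.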
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR: book_training_3
+
+let test_samples = 10_000;
+let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
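+// Each row is an (image, label) pair: the image is a group whose `bytes` field
+// holds a PNG-encoded 28x28 digit, and the label is a 64-bit integer.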
+for row in test_parquet {
+    for (_name, field) in row?.get_column_iter() {
+        if let parquet::record::Field::Group(subrow) = field {
+            for (_name, field) in subrow.get_column_iter() {
+                if let parquet::record::Field::Bytes(value) = field {
+                    let image = image::load_from_memory(value.data()).unwrap();
+                    test_buffer_images.extend(image.to_luma8().as_raw());
+                }
+            }
+        } else if let parquet::record::Field::Long(label) = field {
+            test_buffer_labels.push(*label as u8);
+        }
+    }
+}
+let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;
+
+let train_samples = 60_000;
+let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+for row in train_parquet {
+    for (_name, field) in row?.get_column_iter() {
+        if let parquet::record::Field::Group(subrow) = field {
+            for (_name, field) in subrow.get_column_iter() {
+                if let parquet::record::Field::Bytes(value) = field {
+                    let image = image::load_from_memory(value.data()).unwrap();
+                    train_buffer_images.extend(image.to_luma8().as_raw());
+                }
+            }
+        } else if let parquet::record::Field::Long(label) = field {
+            train_buffer_labels.push(*label as u8);
+        }
+    }
+}
+let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;
+
+let mnist = candle_datasets::vision::Dataset {
+    train_images,
+    train_labels,
+    test_images,
+    test_labels,
+    labels: 10,
+};
+
+// ANCHOR_END: book_training_3
+assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+assert_eq!(mnist.test_labels.dims(), &[10_000]);
+assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}