From d7a273be516e7795b6213e9f076676f76d0fac11 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 14 Aug 2023 17:23:08 +0200
Subject: [PATCH] Training:

- Removed a lot of surface (SerializedFileReader ownership is really painful).
- Moved example + vision to hf.co version.
- Removed feature gate.
---
 candle-book/src/training/README.md            |  8 +-
 candle-book/src/training/mnist.md             |  9 ++
 candle-datasets/Cargo.toml                    |  6 +-
 candle-datasets/src/vision/mnist.rs           | 53 ++++++++++++
 candle-examples/Cargo.toml                    |  2 +
 .../examples/mnist-training/main.rs           |  2 +-
 candle-examples/src/lib.rs                    | 84 +++++++++++++++++--
 7 files changed, 148 insertions(+), 16 deletions(-)

diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md
index ddbbc7af..d68a917e 100644
--- a/candle-book/src/training/README.md
+++ b/candle-book/src/training/README.md
@@ -6,19 +6,19 @@ start with the Hello world dataset of machine learning, MNIST.
 
 Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
 
-This requires `candle-datasets` with the `hub` feature.
+This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
 
 ```bash
-cargo add candle-datasets --features hub
 cargo add hf-hub
 ```
 
+This is going to be very hands-on for now.
+
 ```rust,ignore
 {{#include ../../../candle-examples/src/lib.rs:book_training_1}}
 ```
 
 This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
-`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`].
+Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`]s.
 
 We can inspect the content of the files with:
@@ -37,5 +37,3 @@ Column id 0, name image, value {bytes: [137, ....]
 
 So each row contains 2 columns (image, label) with image being saved as bytes.
 Let's put them into a useful struct.
-
-
diff --git a/candle-book/src/training/mnist.md b/candle-book/src/training/mnist.md
index 642960a4..8ec00c1c 100644
--- a/candle-book/src/training/mnist.md
+++ b/candle-book/src/training/mnist.md
@@ -1 +1,10 @@
 # MNIST
+
+Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
+```
+
+Parsing the files and loading the data into single tensors requires the whole dataset to fit in memory.
+It is quite rudimentary, but simple enough for a small dataset like MNIST.
diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml
index 7c4c0aa5..4722c76c 100644
--- a/candle-datasets/Cargo.toml
+++ b/candle-datasets/Cargo.toml
@@ -19,7 +19,5 @@ memmap2 = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
 rand = { workspace = true }
 thiserror = { workspace = true }
-parquet = { version = "45.0.0", optional = true }
-
-[features]
-hub = ["dep:parquet"]
+parquet = { version = "45.0.0" }
+image = { workspace = true }
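As a quick orientation (not part of the patch itself): the download flow the book now teaches boils down to a few `hf-hub` calls. Below is a minimal sketch, assuming the `hf-hub` sync API and `anyhow` for error handling, mirroring the `book_training_1` anchor further down:

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> anyhow::Result<()> {
    // Datasets on the Hub expose standardized parquet files on the
    // `refs/convert/parquet` branch.
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "mnist".to_string(),
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    ));
    // Downloads the shard (or reuses the local cache) and returns its path.
    let path = repo.get("mnist/mnist-test.parquet")?;
    println!("mnist-test.parquet cached at {path:?}");
    Ok(())
}
```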
diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index 2267f9a0..c908412c 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -3,6 +3,8 @@
 //! The files can be obtained from the following link:
 //! <http://yann.lecun.com/exdb/mnist/>
 use candle::{DType, Device, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{self, BufReader, Read};
 
@@ -63,3 +65,54 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<Dataset> {
         labels: 10,
     })
 }
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_luma8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
+pub fn load() -> Result<crate::vision::Dataset> {
+    let api = Api::new().unwrap();
+    let dataset_id = "mnist".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
+    let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
+    let test_parquet =
+        SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
+    let train_parquet =
+        SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
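A note for reviewers (not part of the patch): the nested match in `load_parquet` follows the layout of the converted parquet files, where the `image` column is a group wrapping the encoded image bytes and `label` is a plain `i64`. A hypothetical standalone helper makes that per-row shape explicit:

```rust
use parquet::record::{Field, Row};

// Hypothetical helper, for illustration only: decode one MNIST row into
// (raw 28x28 grayscale pixels, label).
fn decode_row(row: &Row) -> Option<(Vec<u8>, u8)> {
    let mut pixels = None;
    let mut label = None;
    for (_name, field) in row.get_column_iter() {
        match field {
            // The `image` column is a group whose inner field holds the
            // encoded (PNG) bytes.
            Field::Group(subrow) => {
                for (_name, field) in subrow.get_column_iter() {
                    if let Field::Bytes(value) = field {
                        let image = image::load_from_memory(value.data()).ok()?;
                        pixels = Some(image.to_luma8().into_raw());
                    }
                }
            }
            // The `label` column is stored as an i64.
            Field::Long(label_value) => label = Some(*label_value as u8),
            _ => {}
        }
    }
    Some((pixels?, label?))
}
```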
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 39d2bc72..cc1f1e4c 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -41,6 +41,8 @@ tracing-subscriber = { workspace = true }
 wav = { workspace = true }
 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
 tokio = "1.29.1"
+parquet = "45.0.0"
+image.workspace = true
 
 [build-dependencies]
 anyhow = { workspace = true }
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index bcf8677d..2cdad9fa 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -143,7 +143,7 @@ struct Args {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     // Load the dataset
-    let m = candle_datasets::vision::mnist::load_dir("data")?;
+    let m = candle_datasets::vision::mnist::load()?;
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
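One behavioral detail worth calling out (this note is not part of the patch): `load()` hands the example images that are already flattened to `[N, 784]` and normalized to `f32` in `[0, 1]`. A tiny sketch of the same `candle` arithmetic the loader uses:

```rust
use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    // u8 pixels -> f32, then scale into [0, 1], exactly as load_parquet does.
    let pixels: Vec<u8> = vec![0, 127, 255];
    let t = (Tensor::from_vec(pixels, (1, 3), &Device::Cpu)?
        .to_dtype(DType::F32)?
        / 255.)?;
    println!("{t}"); // roughly [[0.0, 0.498, 1.0]]
    Ok(())
}
```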
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 3fdd4cc9..0b716e4f 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()>
 #[cfg(test)]
 mod tests {
     use anyhow::Result;
+    use candle::{DType, Device, Tensor};
+    use parquet::file::reader::SerializedFileReader;
 
     // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
     #[rustfmt::skip]
@@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
     #[test]
     fn book_training_1() -> Result<()>{
 // ANCHOR: book_training_1
-use candle_datasets::hub::from_hub;
-use hf_hub::api::sync::Api;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+let dataset_id = "mnist".to_string();
 
 let api = Api::new()?;
-let files = from_hub(&api, "mnist".to_string())?;
+let repo = Repo::with_revision(
+    dataset_id,
+    RepoType::Dataset,
+    "refs/convert/parquet".to_string(),
+);
+let repo = api.repo(repo);
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
 // ANCHOR_END: book_training_1
+// Ignore unused
+let _train = train_parquet;
 // ANCHOR: book_training_2
-let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
-for row in rows {
-    for (idx, (name, field)) in row.get_column_iter().enumerate() {
+for row in test_parquet {
+    for (idx, (name, field)) in row?.get_column_iter().enumerate() {
         println!("Column id {idx}, name {name}, value {field}");
     }
 }
 // ANCHOR_END: book_training_2
+let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+// ANCHOR: book_training_3
+
+let test_samples = 10_000;
+let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
+for row in test_parquet {
+    for (_name, field) in row?.get_column_iter() {
+        if let parquet::record::Field::Group(subrow) = field {
+            for (_name, field) in subrow.get_column_iter() {
+                if let parquet::record::Field::Bytes(value) = field {
+                    let image = image::load_from_memory(value.data()).unwrap();
+                    test_buffer_images.extend(image.to_luma8().as_raw());
+                }
+            }
+        } else if let parquet::record::Field::Long(label) = field {
+            test_buffer_labels.push(*label as u8);
+        }
+    }
+}
+let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;
+
+let train_samples = 60_000;
+let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+for row in train_parquet {
+    for (_name, field) in row?.get_column_iter() {
+        if let parquet::record::Field::Group(subrow) = field {
+            for (_name, field) in subrow.get_column_iter() {
+                if let parquet::record::Field::Bytes(value) = field {
+                    let image = image::load_from_memory(value.data()).unwrap();
+                    train_buffer_images.extend(image.to_luma8().as_raw());
+                }
+            }
+        } else if let parquet::record::Field::Long(label) = field {
+            train_buffer_labels.push(*label as u8);
+        }
+    }
+}
+let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;
+
+let mnist = candle_datasets::vision::Dataset {
+    train_images,
+    train_labels,
+    test_images,
+    test_labels,
+    labels: 10,
+};
+
+// ANCHOR_END: book_training_3
+assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+assert_eq!(mnist.test_labels.dims(), &[10_000]);
+assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+assert_eq!(mnist.train_labels.dims(), &[60_000]);
 Ok(())
     }
 }
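A closing usage sketch (not part of the patch): downstream code such as the mnist-training example can slice mini-batches straight out of the tensors `load()` returns, for instance with `Tensor::narrow`:

```rust
fn main() -> candle::Result<()> {
    let m = candle_datasets::vision::mnist::load()?;
    // First mini-batch of 64 flattened 28x28 images and their labels.
    let batch_images = m.train_images.narrow(0, 0, 64)?;
    let batch_labels = m.train_labels.narrow(0, 0, 64)?;
    assert_eq!(batch_images.dims(), &[64, 784]);
    assert_eq!(batch_labels.dims(), &[64]);
    Ok(())
}
```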