diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md index f4f9eb85..ddbbc7af 100644 --- a/candle-book/src/training/README.md +++ b/candle-book/src/training/README.md @@ -6,12 +6,36 @@ start with the Hello world dataset of machine learning, MNIST. Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist). - -```rust -use candle_datasets::from_hub; +This requires `candle-datasets` with the `hub` feature. +```bash +cargo add candle-datasets --features hub +cargo add hf-hub +``` -let dataset = from_hub("mnist")?; +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_1}} ``` This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset. +`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`]. + +We can inspect the content of the files with: + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_training_2}} +``` + +You should see something like: + +```bash +Column id 1, name label, value 6 +Column id 0, name image, value {bytes: [137, ....] +Column id 1, name label, value 8 +Column id 0, name image, value {bytes: [137, ....] +``` + +So each row contains 2 columns (image, label) with image being saved as bytes. +Let's put them into a useful struct. 
+ + diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs index 53fa8156..c00a788c 100644 --- a/candle-datasets/src/hub.rs +++ b/candle-datasets/src/hub.rs @@ -22,7 +22,7 @@ fn sibling_to_parquet( repo: &ApiRepo, ) -> Result<SerializedFileReader<File>, Error> { let local = repo.get(rfilename)?; - let file = File::open(&local)?; + let file = File::open(local)?; let reader = SerializedFileReader::new(file)?; Ok(reader) } diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 8bf94eb7..3fdd4cc9 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -55,6 +55,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { #[cfg(test)] mod tests { + use anyhow::Result; + // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 #[rustfmt::skip] #[tokio::test] @@ -150,4 +152,25 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un assert_eq!(view.shape(), &[768, 768]); assert_eq!(tp_tensor.dims(), &[192, 768]); } + + #[rustfmt::skip] + #[test] + fn book_training_1() -> Result<()>{ +// ANCHOR: book_training_1 +use candle_datasets::hub::from_hub; +use hf_hub::api::sync::Api; + +let api = Api::new()?; +let files = from_hub(&api, "mnist".to_string())?; +// ANCHOR_END: book_training_1 +// ANCHOR: book_training_2 +let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten(); +for row in rows { + for (idx, (name, field)) in row.get_column_iter().enumerate() { + println!("Column id {idx}, name {name}, value {field}"); + } +} +// ANCHOR_END: book_training_2 +Ok(()) + } }