From 76023236677fbab10fd1c99eab95d268416fb941 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 14 Aug 2023 10:52:12 +0200 Subject: [PATCH] [Book] Add small error management + start training (with generic dataset inclusion). --- candle-book/src/SUMMARY.md | 10 ++-- candle-book/src/training/README.md | 16 +++++ candle-datasets/Cargo.toml | 5 ++ candle-datasets/src/hub.rs | 94 ++++++++++++++++++++++++++++++ candle-datasets/src/lib.rs | 2 + 5 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 candle-datasets/src/hub.rs diff --git a/candle-book/src/SUMMARY.md b/candle-book/src/SUMMARY.md index 3432f66f..6eadb0c1 100644 --- a/candle-book/src/SUMMARY.md +++ b/candle-book/src/SUMMARY.md @@ -12,7 +12,11 @@ - [Running a model](inference/README.md) - [Using the hub](inference/hub.md) -- [Error management]() +- [Error management](error_manage.md) +- [Training](training/README.md) + - [MNIST]() + - [Fine-tuning]() + - [Serialization]() - [Advanced Cuda usage]() - [Writing a custom kernel]() - [Porting a custom kernel]() @@ -21,7 +25,3 @@ - [Creating a WASM app]() - [Creating a REST api webserver]() - [Creating a desktop Tauri app]() -- [Training]() - - [MNIST]() - - [Fine-tuning]() - - [Serialization]() diff --git a/candle-book/src/training/README.md b/candle-book/src/training/README.md index 8977de34..f4f9eb85 100644 --- a/candle-book/src/training/README.md +++ b/candle-book/src/training/README.md @@ -1 +1,17 @@ # Training + + +Training starts with data. We're going to use the huggingface hub and +start with the Hello world dataset of machine learning, MNIST. + +Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist). + + +```rust +use candle_datasets::from_hub; + + +let dataset = from_hub("mnist")?; +``` + +This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset. diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 91fced54..7c4c0aa5 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -18,3 +18,8 @@ intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } rand = { workspace = true } +thiserror = { workspace = true } +parquet = { version = "45.0.0", optional = true } + +[features] +hub = ["dep:parquet"] diff --git a/candle-datasets/src/hub.rs b/candle-datasets/src/hub.rs new file mode 100644 index 00000000..53fa8156 --- /dev/null +++ b/candle-datasets/src/hub.rs @@ -0,0 +1,94 @@ +use hf_hub::{ + api::sync::{Api, ApiRepo}, + Repo, RepoType, +}; +use parquet::file::reader::SerializedFileReader; +use std::fs::File; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("ApiError : {0}")] + ApiError(#[from] hf_hub::api::sync::ApiError), + + #[error("IoError : {0}")] + IoError(#[from] std::io::Error), + + #[error("ParquetError : {0}")] + ParquetError(#[from] parquet::errors::ParquetError), +} + +fn sibling_to_parquet( + rfilename: &str, + repo: &ApiRepo, +) -> Result, Error> { + let local = repo.get(rfilename)?; + let file = File::open(&local)?; + let reader = SerializedFileReader::new(file)?; + Ok(reader) +} + +pub fn from_hub(api: &Api, dataset_id: String) -> Result>, Error> { + let repo = Repo::with_revision( + dataset_id, + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let repo = api.repo(repo); + let info = repo.info()?; + + let files: Result, _> = info + .siblings + .into_iter() + .filter_map(|s| -> Option> { + let filename = s.rfilename; + if filename.ends_with(".parquet") { + let reader_result = sibling_to_parquet(&filename, &repo); + Some(reader_result) + } else { + None + } + }) + .collect(); + let files = files?; + + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dataset() { + let api = Api::new().unwrap(); + let files = from_hub( + &api, + "hf-internal-testing/dummy_image_text_data".to_string(), + ) + .unwrap(); + assert_eq!(files.len(), 1); + + let mut rows = files.into_iter().flat_map(|r| r.into_iter()); + + let row = rows.next().unwrap().unwrap(); + let mut col_iter = row.get_column_iter(); + + // First element is an image + col_iter.next(); + assert_eq!( + col_iter.next().unwrap().1, + &parquet::record::Field::Str("a drawing of a green pokemon with red eyes".to_string()) + ); + + // Keep for now to showcase how to use. + //for row in rows { + // if let Ok(row) = row { + // for (_idx, (_name, field)) in row.get_column_iter().enumerate() { + // if let parquet::record::Field::Str(value) = field { + // println!("Value {value:?}"); + // } + // } + // } + //} + } +} diff --git a/candle-datasets/src/lib.rs b/candle-datasets/src/lib.rs index 42ad5d62..cc4304c2 100644 --- a/candle-datasets/src/lib.rs +++ b/candle-datasets/src/lib.rs @@ -1,5 +1,7 @@ //! Datasets & Dataloaders for Candle pub mod batcher; +#[cfg(feature = "hub")] +pub mod hub; pub mod nlp; pub mod vision;