Training:

- Removed a lot of API surface (`SerializedFileReader` ownership is really
  painful).
- Moved example + vision to hf.co version.
- Removed feature gate.
Nicolas Patry
2023-08-14 17:23:08 +02:00
parent dd02f589c0
commit d7a273be51
7 changed files with 148 additions and 16 deletions

View File

@@ -6,19 +6,19 @@ start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
-This requires `candle-datasets` with the `hub` feature.
+This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
-cargo add candle-datasets --features hub
+cargo add hf-hub
```
This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
-`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`].
+Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
We can inspect the content of the files with:
@@ -37,5 +37,3 @@ Column id 0, name image, value {bytes: [137, ....]
So each row contains two columns (image, label), with the image stored as bytes.
Let's put them into a useful struct.
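For readers of this page, the `book_training_1` include above resolves to the anchor in `candle-examples/src/lib.rs` further down in this diff. A condensed, standalone sketch (the `main` wrapper and `anyhow` error handling are assumptions of this sketch, not part of the book snippet):

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};

fn main() -> anyhow::Result<()> {
    // Fetch the pre-converted parquet files from the dataset's
    // `refs/convert/parquet` revision on the hub.
    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(
        "mnist".to_string(),
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    ));
    let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
    // Each reader owns its underlying `File`, hence the ownership pain
    // mentioned in the commit message.
    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
    println!("rows: {}", test_parquet.metadata().file_metadata().num_rows());
    Ok(())
}
```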

View File

@@ -1 +1,10 @@
# MNIST
+Now that we have downloaded the MNIST parquet files, let's put them in a simple struct.
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
+```
+Parsing the files and loading the data into single tensors requires the dataset to fit entirely in memory.
+It is quite rudimentary, but simple enough for a small dataset like MNIST.
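For context, the struct being filled is `candle_datasets::vision::Dataset`. Reconstructed from its usage in this commit (the authoritative definition lives in `candle-datasets`), it looks roughly like:

```rust
use candle::Tensor;

// Reconstructed from usage in this commit, not copied from the crate.
pub struct Dataset {
    pub train_images: Tensor, // (n_train, 784) f32, scaled to [0, 1]
    pub train_labels: Tensor, // (n_train,) u8
    pub test_images: Tensor,  // (n_test, 784) f32, scaled to [0, 1]
    pub test_labels: Tensor,  // (n_test,) u8
    pub labels: usize,        // number of classes, 10 for MNIST
}
```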

View File

@@ -19,7 +19,5 @@ memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true }
thiserror = { workspace = true }
parquet = { version = "45.0.0", optional = true }
[features]
hub = ["dep:parquet"]
parquet = { version = "45.0.0"}
image = { workspace = true }

View File

@@ -3,6 +3,8 @@
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
@@ -63,3 +65,54 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
        labels: 10,
    })
}
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_luma8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+pub fn load() -> Result<crate::vision::Dataset> {
+    let api = Api::new().unwrap();
+    let dataset_id = "mnist".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
+    let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
+    let test_parquet =
+        SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
+    let train_parquet =
+        SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
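With the feature gate gone, the new `load()` entry point replaces `load_dir` for the hub path. A minimal usage sketch, mirroring the `mnist-training` example change below (the `main` wrapper and `anyhow` are assumptions of this sketch; shapes per the test assertions at the end of this diff):

```rust
// Fetch MNIST from the hub and inspect the resulting tensor shapes.
fn main() -> anyhow::Result<()> {
    let m = candle_datasets::vision::mnist::load()?;
    println!("train-images: {:?}", m.train_images.shape()); // (60_000, 784)
    println!("test-images: {:?}", m.test_images.shape()); // (10_000, 784)
    Ok(())
}
```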

View File

@@ -41,6 +41,8 @@ tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = "45.0.0"
image.workspace = true
[build-dependencies]
anyhow = { workspace = true }

View File

@@ -143,7 +143,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Load the dataset
-    let m = candle_datasets::vision::mnist::load_dir("data")?;
+    let m = candle_datasets::vision::mnist::load()?;
    println!("train-images: {:?}", m.train_images.shape());
    println!("train-labels: {:?}", m.train_labels.shape());
    println!("test-images: {:?}", m.test_images.shape());

View File

@@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
#[cfg(test)]
mod tests {
    use anyhow::Result;
+    use candle::{DType, Device, Tensor};
+    use parquet::file::reader::SerializedFileReader;
    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
    #[rustfmt::skip]
@@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
    #[test]
    fn book_training_1() -> Result<()> {
        // ANCHOR: book_training_1
-        use candle_datasets::hub::from_hub;
-        use hf_hub::api::sync::Api;
+        use hf_hub::{api::sync::Api, Repo, RepoType};
+        let dataset_id = "mnist".to_string();
        let api = Api::new()?;
-        let files = from_hub(&api, "mnist".to_string())?;
+        let repo = Repo::with_revision(
+            dataset_id,
+            RepoType::Dataset,
+            "refs/convert/parquet".to_string(),
+        );
+        let repo = api.repo(repo);
+        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR_END: book_training_1
+        // Ignore unused
+        let _train = train_parquet;
        // ANCHOR: book_training_2
-        let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
-        for row in rows {
-            for (idx, (name, field)) in row.get_column_iter().enumerate() {
+        for row in test_parquet {
+            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
        // ANCHOR_END: book_training_2
+        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
+        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
+        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
+        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
+        // ANCHOR: book_training_3
+        let test_samples = 10_000;
+        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
+        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
+        for row in test_parquet {
+            for (_name, field) in row?.get_column_iter() {
+                if let parquet::record::Field::Group(subrow) = field {
+                    for (_name, field) in subrow.get_column_iter() {
+                        if let parquet::record::Field::Bytes(value) = field {
+                            let image = image::load_from_memory(value.data()).unwrap();
+                            test_buffer_images.extend(image.to_luma8().as_raw());
+                        }
+                    }
+                } else if let parquet::record::Field::Long(label) = field {
+                    test_buffer_labels.push(*label as u8);
+                }
+            }
+        }
+        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;
+        let train_samples = 60_000;
+        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
+        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
+        for row in train_parquet {
+            for (_name, field) in row?.get_column_iter() {
+                if let parquet::record::Field::Group(subrow) = field {
+                    for (_name, field) in subrow.get_column_iter() {
+                        if let parquet::record::Field::Bytes(value) = field {
+                            let image = image::load_from_memory(value.data()).unwrap();
+                            train_buffer_images.extend(image.to_luma8().as_raw());
+                        }
+                    }
+                } else if let parquet::record::Field::Long(label) = field {
+                    train_buffer_labels.push(*label as u8);
+                }
+            }
+        }
+        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
+        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;
+        let mnist = candle_datasets::vision::Dataset {
+            train_images,
+            train_labels,
+            test_images,
+            test_labels,
+            labels: 10,
+        };
+        // ANCHOR_END: book_training_3
+        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
+        assert_eq!(mnist.test_labels.dims(), &[10_000]);
+        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
+        assert_eq!(mnist.train_labels.dims(), &[60_000]);
        Ok(())
    }
}
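To exercise the book snippets end to end, running the test by name should work; the exact invocation is an assumption, not part of this commit:

```bash
cargo test -p candle-examples book_training_1
```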