mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Training:
- Removed a lot of surface (SerializedFileReader ownership is really painful). - Moved example + vision to hf.co version. - Removed feature gate.
This commit is contained in:
@ -6,19 +6,19 @@ start with the Hello world dataset of machine learning, MNIST.
|
||||
|
||||
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
|
||||
|
||||
This requires `candle-datasets` with the `hub` feature.
|
||||
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
|
||||
```bash
|
||||
cargo add candle-datasets --features hub
|
||||
cargo add hf-hub
|
||||
```
|
||||
|
||||
This is going to be very hands-on for now.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
|
||||
```
|
||||
|
||||
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
|
||||
`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`].
|
||||
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
|
||||
|
||||
We can inspect the content of the files with:
|
||||
|
||||
@ -37,5 +37,3 @@ Column id 0, name image, value {bytes: [137, ....]
|
||||
|
||||
So each row contains 2 columns (image, label) with image being saved as bytes.
|
||||
Let's put them into a useful struct.
|
||||
|
||||
|
||||
|
@ -1 +1,10 @@
|
||||
# MNIST
|
||||
|
||||
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
|
||||
```
|
||||
|
||||
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
|
||||
It is quite rudimentary, but simple enough for a small dataset like MNIST.
|
||||
|
@ -19,7 +19,5 @@ memmap2 = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
rand = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
parquet = { version = "45.0.0", optional = true }
|
||||
|
||||
[features]
|
||||
hub = ["dep:parquet"]
|
||||
parquet = { version = "45.0.0"}
|
||||
image = { workspace = true }
|
||||
|
@ -3,6 +3,8 @@
|
||||
//! The files can be obtained from the following link:
|
||||
//! <http://yann.lecun.com/exdb/mnist/>
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufReader, Read};
|
||||
|
||||
@ -63,3 +65,54 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
||||
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
|
||||
let samples = parquet.metadata().file_metadata().num_rows() as usize;
|
||||
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
|
||||
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
|
||||
for row in parquet.into_iter().flatten() {
|
||||
for (_name, field) in row.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
} else if let parquet::record::Field::Long(label) = field {
|
||||
buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
|
||||
.to_dtype(DType::F32)?
|
||||
/ 255.)?;
|
||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||
Ok((images, labels))
|
||||
}
|
||||
|
||||
pub fn load() -> Result<crate::vision::Dataset> {
|
||||
let api = Api::new().unwrap();
|
||||
let dataset_id = "mnist".to_string();
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
|
||||
let test_parquet =
|
||||
SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
|
||||
let train_parquet =
|
||||
SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
|
||||
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
||||
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
||||
Ok(crate::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
})
|
||||
}
|
||||
|
@ -41,6 +41,8 @@ tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
parquet = "45.0.0"
|
||||
image.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -143,7 +143,7 @@ struct Args {
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
// Load the dataset
|
||||
let m = candle_datasets::vision::mnist::load_dir("data")?;
|
||||
let m = candle_datasets::vision::mnist::load()?;
|
||||
println!("train-images: {:?}", m.train_images.shape());
|
||||
println!("train-labels: {:?}", m.train_labels.shape());
|
||||
println!("test-images: {:?}", m.test_images.shape());
|
||||
|
@ -56,6 +56,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
@ -157,20 +159,90 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use candle_datasets::hub::from_hub;
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
let dataset_id = "mnist".to_string();
|
||||
|
||||
let api = Api::new()?;
|
||||
let files = from_hub(&api, "mnist".to_string())?;
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// Ignore unused
|
||||
let _train = train_parquet;
|
||||
// ANCHOR: book_training_2
|
||||
let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
|
||||
for row in rows {
|
||||
for (idx, (name, field)) in row.get_column_iter().enumerate() {
|
||||
for row in test_parquet {
|
||||
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR: book_training_3
|
||||
|
||||
let test_samples = 10_000;
|
||||
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
|
||||
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
|
||||
for row in test_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
test_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
test_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
|
||||
|
||||
let train_samples = 60_000;
|
||||
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
|
||||
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
|
||||
for row in train_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
train_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
train_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
|
||||
|
||||
let mnist = candle_datasets::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
};
|
||||
|
||||
// ANCHOR_END: book_training_3
|
||||
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
|
||||
assert_eq!(mnist.test_labels.dims(), &[10_000]);
|
||||
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
|
||||
assert_eq!(mnist.train_labels.dims(), &[60_000]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user