Merge pull request #439 from huggingface/training_hub_dataset

[Book] Add small error management + start training (with generic dataset inclusion).
This commit is contained in:
Nicolas Patry
2023-08-29 13:10:05 +02:00
committed by GitHub
14 changed files with 444 additions and 110 deletions

View File

@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable - name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt update -y && apt install libssl-dev -y - run: apt-get update -y && apt-get install libssl-dev -y
- name: Test (cuda) - name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner: stop-runner:

View File

@ -3,6 +3,7 @@ members = [
"candle-core", "candle-core",
"candle-datasets", "candle-datasets",
"candle-examples", "candle-examples",
"candle-book",
"candle-nn", "candle-nn",
"candle-pyo3", "candle-pyo3",
"candle-transformers", "candle-transformers",
@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7" tracing-subscriber = "0.3.7"
wav = "1.0.0" wav = "1.0.0"
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
parquet = { version = "45.0.0" }
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

49
candle-book/Cargo.toml Normal file
View File

@ -0,0 +1,49 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.0" }
candle-nn = { path = "../candle-nn", version = "0.2.0" }
candle-transformers = { path = "../candle-transformers", version = "0.2.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }
[build-dependencies]
anyhow = { workspace = true }
[features]
default = []

View File

@ -12,7 +12,11 @@
- [Running a model](inference/README.md) - [Running a model](inference/README.md)
- [Using the hub](inference/hub.md) - [Using the hub](inference/hub.md)
- [Error management]() - [Error management](error_manage.md)
- [Training](training/README.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()
- [Advanced Cuda usage]() - [Advanced Cuda usage]()
- [Writing a custom kernel]() - [Writing a custom kernel]()
- [Porting a custom kernel]() - [Porting a custom kernel]()
@ -21,7 +25,3 @@
- [Creating a WASM app]() - [Creating a WASM app]()
- [Creating a REST api webserver]() - [Creating a REST api webserver]()
- [Creating a desktop Tauri app]() - [Creating a desktop Tauri app]()
- [Training]()
- [MNIST]()
- [Fine-tuning]()
- [Serialization]()

View File

@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
```rust,ignore ```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately: # This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706) # See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}} {{#include ../lib.rs:book_hub_1}}
``` ```
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
and will definitely be slower on network mounted disk, because it will issue more read calls. and will definitely be slower on network mounted disk, because it will issue more read calls.
```rust,ignore ```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}} {{#include ../lib.rs:book_hub_2}}
``` ```
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety). **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@ -100,5 +100,5 @@ cargo add safetensors
```rust,ignore ```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}} {{#include ../lib.rs:book_hub_3}}
``` ```

193
candle-book/src/lib.rs Normal file
View File

@ -0,0 +1,193 @@
#[cfg(test)]
mod tests {
use anyhow::Result;
use candle::{DType, Device, Tensor};
use parquet::file::reader::SerializedFileReader;
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
#[rustfmt::skip]
#[test]
fn book_training_1() -> Result<()>{
// ANCHOR: book_training_1
use hf_hub::{api::sync::Api, Repo, RepoType};
let dataset_id = "mnist".to_string();
let api = Api::new()?;
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR_END: book_training_1
// Ignore unused
let _train = train_parquet;
// ANCHOR: book_training_2
for row in test_parquet {
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
println!("Column id {idx}, name {name}, value {field}");
}
}
// ANCHOR_END: book_training_2
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
// ANCHOR: book_training_3
let test_samples = 10_000;
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
for row in test_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
for row in train_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
// ANCHOR_END: book_training_3
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
assert_eq!(mnist.test_labels.dims(), &[10_000]);
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
assert_eq!(mnist.train_labels.dims(), &[60_000]);
Ok(())
}
}

View File

@ -1 +1,39 @@
# Training # Training
Training starts with data. We're going to use the huggingface hub and
start with the Hello world dataset of machine learning, MNIST.
Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
```bash
cargo add hf-hub
```
This is going to be very hands-on for now.
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
```
This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
We can inspect the content of the files with:
```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
```
You should see something like:
```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```
So each row contains 2 columns (image, label) with image being saved as bytes.
Let's put them into a useful struct.

View File

@ -1 +1,10 @@
# MNIST # MNIST
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
```rust,ignore
{{#include ../lib.rs:book_training_3}}
```
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
It is quite rudimentary, but simple enough for a small dataset like MNIST.

View File

@ -18,3 +18,6 @@ intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true } memmap2 = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] } tokenizers = { workspace = true, features = ["onig"] }
rand = { workspace = true } rand = { workspace = true }
thiserror = { workspace = true }
parquet = { workspace = true}
image = { workspace = true }

View File

@ -0,0 +1,73 @@
use hf_hub::{
api::sync::{Api, ApiRepo},
Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("ApiError : {0}")]
ApiError(#[from] hf_hub::api::sync::ApiError),
#[error("IoError : {0}")]
IoError(#[from] std::io::Error),
#[error("ParquetError : {0}")]
ParquetError(#[from] parquet::errors::ParquetError),
}
fn sibling_to_parquet(
rfilename: &str,
repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
let local = repo.get(rfilename)?;
let file = File::open(local)?;
let reader = SerializedFileReader::new(file)?;
Ok(reader)
}
pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let info = repo.info()?;
let files: Result<Vec<_>, _> = info
.siblings
.into_iter()
.filter_map(|s| -> Option<Result<_, _>> {
let filename = s.rfilename;
if filename.ends_with(".parquet") {
let reader_result = sibling_to_parquet(&filename, &repo);
Some(reader_result)
} else {
None
}
})
.collect();
let files = files?;
Ok(files)
}
#[cfg(test)]
mod tests {
use super::*;
use parquet::file::reader::FileReader;
#[test]
fn test_dataset() {
let api = Api::new().unwrap();
let files = from_hub(
&api,
"hf-internal-testing/dummy_image_text_data".to_string(),
)
.unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
}
}

View File

@ -1,5 +1,6 @@
//! Datasets & Dataloaders for Candle //! Datasets & Dataloaders for Candle
pub mod batcher; pub mod batcher;
pub mod hub;
pub mod nlp; pub mod nlp;
pub mod vision; pub mod vision;

View File

@ -2,7 +2,9 @@
//! //!
//! The files can be obtained from the following link: //! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/> //! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File; use std::fs::File;
use std::io::{self, BufReader, Read}; use std::io::{self, BufReader, Read};
@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Data
labels: 10, labels: 10,
}) })
} }
fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
let samples = parquet.metadata().file_metadata().num_rows() as usize;
let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
for row in parquet.into_iter().flatten() {
for (_name, field) in row.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_luma8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
buffer_labels.push(*label as u8);
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
.to_dtype(DType::F32)?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))
}
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo
.get("mnist/mnist-test.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("mnist/mnist-train.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
})
}

View File

@ -138,12 +138,20 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format. /// The file where to load the trained weights from, in safetensors format.
#[arg(long)] #[arg(long)]
load: Option<String>, load: Option<String>,
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
local_mnist: Option<String>,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
// Load the dataset // Load the dataset
let m = candle_datasets::vision::mnist::load_dir("data")?; let m = if let Some(directory) = args.local_mnist {
candle_datasets::vision::mnist::load_dir(directory)?
} else {
candle_datasets::vision::mnist::load()?
};
println!("train-images: {:?}", m.train_images.shape()); println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape()); println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape()); println!("test-images: {:?}", m.test_images.shape());

View File

@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
image.save(p).map_err(candle::Error::wrap)?; image.save(p).map_err(candle::Error::wrap)?;
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}
#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}