Mirror of https://github.com/huggingface/candle.git
Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
.github/workflows/ci_cuda.yaml (vendored, 2 changed lines)

@@ -59,7 +59,7 @@ jobs:
     - name: Install Rust Stable
       run: curl https://sh.rustup.rs -sSf | sh -s -- -y
     - uses: Swatinem/rust-cache@v2
-    - run: apt update -y && apt install libssl-dev -y
+    - run: apt-get update -y && apt-get install libssl-dev -y
     - name: Test (cuda)
       run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
   stop-runner:
Cargo.toml (workspace root)

@@ -3,6 +3,7 @@ members = [
     "candle-core",
     "candle-datasets",
     "candle-examples",
+    "candle-book",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
@@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 zip = { version = "0.6.6", default-features = false }
+parquet = { version = "45.0.0" }

 [profile.release-with-debug]
 inherits = "release"
candle-book/Cargo.toml (new file, 49 lines)

[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.0" }
candle-nn = { path = "../candle-nn", version = "0.2.0" }
candle-transformers = { path = "../candle-transformers", version = "0.2.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
candle-book/src/SUMMARY.md

@@ -12,7 +12,11 @@

 - [Running a model](inference/README.md)
     - [Using the hub](inference/hub.md)
-- [Error management]()
+- [Error management](error_manage.md)
+- [Training](training/README.md)
+    - [MNIST](training/mnist.md)
+    - [Fine-tuning]()
+    - [Serialization]()
 - [Advanced Cuda usage]()
     - [Writing a custom kernel]()
     - [Porting a custom kernel]()
@@ -21,7 +25,3 @@
 - [Creating a WASM app]()
     - [Creating a REST api webserver]()
     - [Creating a desktop Tauri app]()
-- [Training]()
-    - [MNIST]()
-    - [Fine-tuning]()
-    - [Serialization]()
candle-book/src/inference/hub.md

@@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
 ```rust,ignore
 # This is tested directly in examples crate because it needs external dependencies unfortunately:
 # See [this](https://github.com/rust-lang/mdBook/issues/706)
-{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+{{#include ../lib.rs:book_hub_1}}
 ```


@@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2`]
 and will definitely be slower on network mounted disk, because it will issue more read calls.

 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+{{#include ../lib.rs:book_hub_2}}
 ```

 **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@@ -100,5 +100,5 @@ cargo add safetensors


 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+{{#include ../lib.rs:book_hub_3}}
 ```
candle-book/src/lib.rs (new file, 193 lines)

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use candle::{DType, Device, Tensor};
    use parquet::file::reader::SerializedFileReader;

    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
    #[rustfmt::skip]
    #[tokio::test]
    async fn book_hub_1() {
        // ANCHOR: book_hub_1
        use candle::Device;
        use hf_hub::api::tokio::Api;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());

        let weights_filename = repo.get("model.safetensors").await.unwrap();

        let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_1
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_2() {
        // ANCHOR: book_hub_2
        use candle::Device;
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };
        let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_2
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_3() {
        // ANCHOR: book_hub_3
        use candle::{DType, Device, Tensor};
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use safetensors::slice::IndexOp;
        use safetensors::SafeTensors;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };

        // Use safetensors directly
        let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
        let view = tensors
            .tensor("bert.encoder.layer.0.attention.self.query.weight")
            .unwrap();

        // We're going to load shard with rank 1, within a world_size of 4
        // We're going to split along dimension 0 doing VIEW[start..stop, :]
        let rank = 1;
        let world_size = 4;
        let dim = 0;
        let dtype = view.dtype();
        let mut tp_shape = view.shape().to_vec();
        let size = tp_shape[0];

        if size % world_size != 0 {
            panic!("The dimension is not divisible by `world_size`");
        }
        let block_size = size / world_size;
        let start = rank * block_size;
        let stop = (rank + 1) * block_size;

        // Everything is expressed in tensor dimension
        // bytes offsets is handled automatically for safetensors.

        let iterator = view.slice(start..stop).unwrap();

        tp_shape[dim] = block_size;

        // Convert safetensors Dtype to candle DType
        let dtype: DType = dtype.try_into().unwrap();

        // TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
        let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_3
        assert_eq!(view.shape(), &[768, 768]);
        assert_eq!(tp_tensor.dims(), &[192, 768]);
    }

    #[rustfmt::skip]
    #[test]
    fn book_training_1() -> Result<()> {
        // ANCHOR: book_training_1
        use hf_hub::{api::sync::Api, Repo, RepoType};

        let dataset_id = "mnist".to_string();

        let api = Api::new()?;
        let repo = Repo::with_revision(
            dataset_id,
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let repo = api.repo(repo);
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR_END: book_training_1
        // Ignore unused
        let _train = train_parquet;
        // ANCHOR: book_training_2
        for row in test_parquet {
            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
        // ANCHOR_END: book_training_2
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR: book_training_3

        let test_samples = 10_000;
        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
        for row in test_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            test_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    test_buffer_labels.push(*label as u8);
                }
            }
        }
        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;

        let train_samples = 60_000;
        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
        for row in train_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            train_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    train_buffer_labels.push(*label as u8);
                }
            }
        }
        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;

        let mnist = candle_datasets::vision::Dataset {
            train_images,
            train_labels,
            test_images,
            test_labels,
            labels: 10,
        };

        // ANCHOR_END: book_training_3
        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
        assert_eq!(mnist.test_labels.dims(), &[10_000]);
        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
        assert_eq!(mnist.train_labels.dims(), &[60_000]);
        Ok(())
    }
}
candle-book/src/training/README.md

@@ -1 +1,39 @@
 # Training
+
+
+Training starts with data. We're going to use the huggingface hub and
+start with the Hello world dataset of machine learning, MNIST.
+
+Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
+
+This requires [`hf-hub`](https://github.com/huggingface/hf-hub).
+```bash
+cargo add hf-hub
+```
+
+This is going to be very hands-on for now.
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
+```
+
+This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
+Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].
+
+We can inspect the content of the files with:
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
+```
+
+You should see something like:
+
+```bash
+Column id 1, name label, value 6
+Column id 0, name image, value {bytes: [137, ....]
+Column id 1, name label, value 8
+Column id 0, name image, value {bytes: [137, ....]
+```
+
+So each row contains 2 columns (image, label) with image being saved as bytes.
+Let's put them into a useful struct.
candle-book/src/training/mnist.md

@@ -1 +1,10 @@
 # MNIST
+
+Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
+
+```rust,ignore
+{{#include ../lib.rs:book_training_3}}
+```
+
+Parsing the files and building single tensors requires the whole dataset to fit in memory.
+It is quite rudimentary, but simple enough for a small dataset like MNIST.
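For a sense of scale, here is a rough back-of-the-envelope sketch (not part of the diff) of what holding MNIST entirely in memory means, using the sample counts and image size from the loaders in this PR:

```rust
// Rough memory estimate for keeping all of MNIST in memory, as the loaders above do.
// The numbers (60k train, 10k test, 784 bytes per image) come from the code in this PR.
fn main() {
    let train = 60_000usize;
    let test = 10_000usize;
    let pixels = 28 * 28; // 784 bytes per image in the u8 buffers
    let u8_bytes = (train + test) * pixels; // raw byte buffers built from the parquet rows
    let f32_bytes = u8_bytes * 4; // after the `.to_dtype(DType::F32)` conversion
    println!("u8 buffers : ~{} MiB", u8_bytes / (1024 * 1024)); // ~52 MiB
    println!("f32 tensors: ~{} MiB", f32_bytes / (1024 * 1024)); // ~209 MiB
}
```

A few hundred megabytes is fine for MNIST, which is why the rudimentary approach is acceptable here; a larger dataset would need streaming or batched loading instead.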
candle-datasets/Cargo.toml

@@ -18,3 +18,6 @@ intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
 rand = { workspace = true }
+thiserror = { workspace = true }
+parquet = { workspace = true }
+image = { workspace = true }
candle-datasets/src/hub.rs (new file, 73 lines)

use hf_hub::{
    api::sync::{Api, ApiRepo},
    Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("ApiError : {0}")]
    ApiError(#[from] hf_hub::api::sync::ApiError),

    #[error("IoError : {0}")]
    IoError(#[from] std::io::Error),

    #[error("ParquetError : {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
}

fn sibling_to_parquet(
    rfilename: &str,
    repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
    let local = repo.get(rfilename)?;
    let file = File::open(local)?;
    let reader = SerializedFileReader::new(file)?;
    Ok(reader)
}

pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    );
    let repo = api.repo(repo);
    let info = repo.info()?;

    let files: Result<Vec<_>, _> = info
        .siblings
        .into_iter()
        .filter_map(|s| -> Option<Result<_, _>> {
            let filename = s.rfilename;
            if filename.ends_with(".parquet") {
                let reader_result = sibling_to_parquet(&filename, &repo);
                Some(reader_result)
            } else {
                None
            }
        })
        .collect();
    let files = files?;

    Ok(files)
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::file::reader::FileReader;

    #[test]
    fn test_dataset() {
        let api = Api::new().unwrap();
        let files = from_hub(
            &api,
            "hf-internal-testing/dummy_image_text_data".to_string(),
        )
        .unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
    }
}
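To show how this helper is meant to be consumed outside of the test, here is a minimal usage sketch (not part of the diff; the dataset id is only an example, and any hub dataset with a `refs/convert/parquet` revision works):

```rust
// Hypothetical consumer of the new candle_datasets::hub::from_hub helper.
use hf_hub::api::sync::Api;
use parquet::file::reader::FileReader;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Download every `.parquet` sibling of the dataset's `refs/convert/parquet` revision
    // and open a SerializedFileReader for each one.
    let readers = candle_datasets::hub::from_hub(&api, "mnist".to_string())?;
    println!("got {} parquet file(s)", readers.len());

    for reader in readers {
        println!("rows: {}", reader.metadata().file_metadata().num_rows());
        // Row-level access works exactly like the book examples above.
        for row in reader {
            for (name, field) in row?.get_column_iter() {
                println!("column {name}: {field}");
            }
        }
    }
    Ok(())
}
```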
candle-datasets/src/lib.rs

@@ -1,5 +1,6 @@
 //! Datasets & Dataloaders for Candle
 pub mod batcher;
+pub mod hub;
 pub mod nlp;
 pub mod vision;

candle-datasets/src/vision/mnist.rs

@@ -2,7 +2,9 @@
 //!
 //! The files can be obtained from the following link:
 //! <http://yann.lecun.com/exdb/mnist/>
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{self, BufReader, Read};

@@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset> {
         labels: 10,
     })
 }
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_luma8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
+pub fn load() -> Result<crate::vision::Dataset> {
+    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let dataset_id = "mnist".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo
+        .get("mnist/mnist-test.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let train_parquet_filename = repo
+        .get("mnist/mnist-train.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
candle-examples/examples/mnist-training/main.rs

@@ -138,12 +138,20 @@ struct Args {
     /// The file where to load the trained weights from, in safetensors format.
     #[arg(long)]
     load: Option<String>,
+
+    /// The directory where to load a local copy of the MNIST dataset from.
+    #[arg(long)]
+    local_mnist: Option<String>,
 }

 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     // Load the dataset
-    let m = candle_datasets::vision::mnist::load_dir("data")?;
+    let m = if let Some(directory) = args.local_mnist {
+        candle_datasets::vision::mnist::load_dir(directory)?
+    } else {
+        candle_datasets::vision::mnist::load()?
+    };
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
candle-examples/src/lib.rs

@@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }

(The 99 removed lines are the old `#[cfg(test)] mod tests` block containing the `book_hub_1`, `book_hub_2`, and `book_hub_3` tests, which this PR moves unchanged into candle-book/src/lib.rs, shown above.)