mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Merge pull request #439 from huggingface/training_hub_dataset
[Book] Add small error management + start training (with generic dataset inclusion).
.github/workflows/ci_cuda.yaml (vendored, 2 lines changed)

```diff
@@ -59,7 +59,7 @@ jobs:
       - name: Install Rust Stable
         run: curl https://sh.rustup.rs -sSf | sh -s -- -y
       - uses: Swatinem/rust-cache@v2
-      - run: apt update -y && apt install libssl-dev -y
+      - run: apt-get update -y && apt-get install libssl-dev -y
       - name: Test (cuda)
         run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
   stop-runner:
```
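The `apt` → `apt-get` switch is the right call for CI: `apt` is intended for interactive use and explicitly warns that its CLI is not stable for scripts, whereas `apt-get`'s interface is.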
Cargo.toml

```diff
@@ -3,6 +3,7 @@ members = [
     "candle-core",
     "candle-datasets",
     "candle-examples",
+    "candle-book",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
@@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 zip = { version = "0.6.6", default-features = false }
+parquet = { version = "45.0.0" }
 
 [profile.release-with-debug]
 inherits = "release"
```
candle-book/Cargo.toml (new file, 49 lines)

```toml
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.0" }
candle-nn = { path = "../candle-nn", version = "0.2.0" }
candle-transformers = { path = "../candle-transformers", version = "0.2.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
```
candle-book/src/SUMMARY.md

```diff
@@ -12,7 +12,11 @@
 
 - [Running a model](inference/README.md)
     - [Using the hub](inference/hub.md)
-- [Error management]()
+- [Error management](error_manage.md)
+- [Training](training/README.md)
+    - [MNIST](training/mnist.md)
+    - [Fine-tuning]()
+    - [Serialization]()
 - [Advanced Cuda usage]()
     - [Writing a custom kernel]()
     - [Porting a custom kernel]()
@@ -21,7 +25,3 @@
 - [Creating a WASM app]()
 - [Creating a REST api webserver]()
 - [Creating a desktop Tauri app]()
-- [Training]()
-    - [MNIST]()
-    - [Fine-tuning]()
-    - [Serialization]()
```
candle-book/src/inference/hub.md

````diff
@@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
 ```rust,ignore
 # This is tested directly in examples crate because it needs external dependencies unfortunately:
 # See [this](https://github.com/rust-lang/mdBook/issues/706)
-{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+{{#include ../lib.rs:book_hub_1}}
 ```
 
@@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2`]
 and will definitely be slower on network mounted disk, because it will issue more read calls.
 
 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+{{#include ../lib.rs:book_hub_2}}
 ```
 
 **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
@@ -100,5 +100,5 @@ cargo add safetensors
 
 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+{{#include ../lib.rs:book_hub_3}}
 ```
````
candle-book/src/lib.rs (new file, 193 lines)

```rust
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use candle::{DType, Device, Tensor};
    use parquet::file::reader::SerializedFileReader;

    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
    #[rustfmt::skip]
    #[tokio::test]
    async fn book_hub_1() {
        // ANCHOR: book_hub_1
        use candle::Device;
        use hf_hub::api::tokio::Api;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());

        let weights_filename = repo.get("model.safetensors").await.unwrap();

        let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_1
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_2() {
        // ANCHOR: book_hub_2
        use candle::Device;
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };
        let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_2
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_3() {
        // ANCHOR: book_hub_3
        use candle::{DType, Device, Tensor};
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use safetensors::slice::IndexOp;
        use safetensors::SafeTensors;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };

        // Use safetensors directly
        let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
        let view = tensors
            .tensor("bert.encoder.layer.0.attention.self.query.weight")
            .unwrap();

        // We're going to load the shard with rank 1, within a world_size of 4.
        // We're going to split along dimension 0 doing VIEW[start..stop, :]
        let rank = 1;
        let world_size = 4;
        let dim = 0;
        let dtype = view.dtype();
        let mut tp_shape = view.shape().to_vec();
        let size = tp_shape[0];

        if size % world_size != 0 {
            panic!("The dimension is not divisible by `world_size`");
        }
        let block_size = size / world_size;
        let start = rank * block_size;
        let stop = (rank + 1) * block_size;

        // Everything is expressed in tensor dimensions;
        // byte offsets are handled automatically by safetensors.

        let iterator = view.slice(start..stop).unwrap();

        tp_shape[dim] = block_size;

        // Convert the safetensors Dtype to a candle DType
        let dtype: DType = dtype.try_into().unwrap();

        // TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
        let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_3
        assert_eq!(view.shape(), &[768, 768]);
        assert_eq!(tp_tensor.dims(), &[192, 768]);
    }

    #[rustfmt::skip]
    #[test]
    fn book_training_1() -> Result<()> {
        // ANCHOR: book_training_1
        use hf_hub::{api::sync::Api, Repo, RepoType};

        let dataset_id = "mnist".to_string();

        let api = Api::new()?;
        let repo = Repo::with_revision(
            dataset_id,
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let repo = api.repo(repo);
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR_END: book_training_1
        // Ignore unused
        let _train = train_parquet;
        // ANCHOR: book_training_2
        for row in test_parquet {
            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
        // ANCHOR_END: book_training_2
        let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
        let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR: book_training_3

        let test_samples = 10_000;
        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
        for row in test_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            test_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    test_buffer_labels.push(*label as u8);
                }
            }
        }
        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples,), &Device::Cpu)?;

        let train_samples = 60_000;
        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
        for row in train_parquet {
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            train_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                } else if let parquet::record::Field::Long(label) = field {
                    train_buffer_labels.push(*label as u8);
                }
            }
        }
        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples,), &Device::Cpu)?;

        let mnist = candle_datasets::vision::Dataset {
            train_images,
            train_labels,
            test_images,
            test_labels,
            labels: 10,
        };

        // ANCHOR_END: book_training_3
        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
        assert_eq!(mnist.test_labels.dims(), &[10_000]);
        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
        assert_eq!(mnist.train_labels.dims(), &[60_000]);
        Ok(())
    }
}
```
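A note on the sharding arithmetic in `book_hub_3`: the query weight is a 768×768 matrix, so with `world_size = 4` each shard has `block_size = 768 / 4 = 192` rows, and rank 1 takes rows `192..384`. That is exactly the `[192, 768]` shape the final assertion checks.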
candle-book/src/training/README.md

@@ -1 +1,39 @@

# Training

Training starts with data. We're going to use the Hugging Face hub and
start with the "Hello World" dataset of machine learning: MNIST.

Let's start by downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).

This requires [`hf-hub`](https://github.com/huggingface/hf-hub).

```bash
cargo add hf-hub
```

This is going to be very hands-on for now.

```rust,ignore
{{#include ../lib.rs:book_training_1}}
```

This uses the standardized `parquet` files from the `refs/convert/parquet` branch of every dataset.
Our handles are now [`parquet::file::serialized_reader::SerializedFileReader`].

We can inspect the content of the files with:

```rust,ignore
{{#include ../lib.rs:book_training_2}}
```

You should see something like:

```bash
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
```

So each row contains two columns (image, label), with the image stored as encoded bytes.
Let's put them into a useful struct.
candle-book/src/training/mnist.md

@@ -1 +1,10 @@

# MNIST

Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.

```rust,ignore
{{#include ../lib.rs:book_training_3}}
```

Parsing the files and loading them into single tensors requires the whole dataset to fit in memory.
That is quite rudimentary, but simple enough for a small dataset like MNIST.
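Once the `Dataset` is built, consuming it is just tensor slicing. Here is a minimal, hypothetical mini-batch loop over the struct built above (`narrow(dim, start, len)` is candle's slice-along-a-dimension op; the training step itself is left as a placeholder):

```rust
use candle::Result;
use candle_datasets::vision::Dataset;

fn iterate(mnist: &Dataset) -> Result<()> {
    let batch_size = 64;
    let n_batches = mnist.train_images.dim(0)? / batch_size;
    for i in 0..n_batches {
        // Rows [i * batch_size, (i + 1) * batch_size) of the flattened images.
        let images = mnist.train_images.narrow(0, i * batch_size, batch_size)?;
        let labels = mnist.train_labels.narrow(0, i * batch_size, batch_size)?;
        // forward pass / loss / backward step would go here
        let _ = (images, labels);
    }
    Ok(())
}
```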
candle-datasets/Cargo.toml

```diff
@@ -18,3 +18,6 @@ intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
 rand = { workspace = true }
+thiserror = { workspace = true }
+parquet = { workspace = true }
+image = { workspace = true }
```
candle-datasets/src/hub.rs (new file, 73 lines)

```rust
use hf_hub::{
    api::sync::{Api, ApiRepo},
    Repo, RepoType,
};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;

#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("ApiError: {0}")]
    ApiError(#[from] hf_hub::api::sync::ApiError),

    #[error("IoError: {0}")]
    IoError(#[from] std::io::Error),

    #[error("ParquetError: {0}")]
    ParquetError(#[from] parquet::errors::ParquetError),
}

fn sibling_to_parquet(
    rfilename: &str,
    repo: &ApiRepo,
) -> Result<SerializedFileReader<File>, Error> {
    let local = repo.get(rfilename)?;
    let file = File::open(local)?;
    let reader = SerializedFileReader::new(file)?;
    Ok(reader)
}

pub fn from_hub(api: &Api, dataset_id: String) -> Result<Vec<SerializedFileReader<File>>, Error> {
    let repo = Repo::with_revision(
        dataset_id,
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    );
    let repo = api.repo(repo);
    let info = repo.info()?;

    let files: Result<Vec<_>, _> = info
        .siblings
        .into_iter()
        .filter_map(|s| -> Option<Result<_, _>> {
            let filename = s.rfilename;
            if filename.ends_with(".parquet") {
                let reader_result = sibling_to_parquet(&filename, &repo);
                Some(reader_result)
            } else {
                None
            }
        })
        .collect();
    let files = files?;

    Ok(files)
}

#[cfg(test)]
mod tests {
    use super::*;
    use parquet::file::reader::FileReader;

    #[test]
    fn test_dataset() {
        let api = Api::new().unwrap();
        let files = from_hub(
            &api,
            "hf-internal-testing/dummy_image_text_data".to_string(),
        )
        .unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].metadata().file_metadata().num_rows(), 20);
    }
}
```
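For context, a minimal, hypothetical usage sketch of `from_hub` (network access and the dataset's `refs/convert/parquet` conversion are assumed):

```rust
use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Fetch every parquet shard of the dataset's `refs/convert/parquet` branch.
    let readers = candle_datasets::hub::from_hub(&api, "mnist".to_string())?;
    for reader in readers {
        // Iterating a SerializedFileReader yields rows; `flatten` skips errors.
        for row in reader.into_iter().flatten() {
            for (name, field) in row.get_column_iter() {
                println!("{name}: {field}");
            }
        }
    }
    Ok(())
}
```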
candle-datasets/src/lib.rs

```diff
@@ -1,5 +1,6 @@
 //! Datasets & Dataloaders for Candle
 pub mod batcher;
+pub mod hub;
 pub mod nlp;
 pub mod vision;
```
candle-datasets/src/vision/mnist.rs

```diff
@@ -2,7 +2,9 @@
 //!
 //! The files can be obtained from the following link:
 //! <http://yann.lecun.com/exdb/mnist/>
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use parquet::file::reader::{FileReader, SerializedFileReader};
 use std::fs::File;
 use std::io::{self, BufReader, Read};
 
@@ -63,3 +65,58 @@ pub fn load_dir<T: AsRef<std::path::Path>>(dir: T) -> Result<crate::vision::Dataset>
         labels: 10,
     })
 }
+
+fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor, Tensor)> {
+    let samples = parquet.metadata().file_metadata().num_rows() as usize;
+    let mut buffer_images: Vec<u8> = Vec::with_capacity(samples * 784);
+    let mut buffer_labels: Vec<u8> = Vec::with_capacity(samples);
+    for row in parquet.into_iter().flatten() {
+        for (_name, field) in row.get_column_iter() {
+            if let parquet::record::Field::Group(subrow) = field {
+                for (_name, field) in subrow.get_column_iter() {
+                    if let parquet::record::Field::Bytes(value) = field {
+                        let image = image::load_from_memory(value.data()).unwrap();
+                        buffer_images.extend(image.to_luma8().as_raw());
+                    }
+                }
+            } else if let parquet::record::Field::Long(label) = field {
+                buffer_labels.push(*label as u8);
+            }
+        }
+    }
+    let images = (Tensor::from_vec(buffer_images, (samples, 784), &Device::Cpu)?
+        .to_dtype(DType::F32)?
+        / 255.)?;
+    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
+    Ok((images, labels))
+}
+
+pub fn load() -> Result<crate::vision::Dataset> {
+    let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let dataset_id = "mnist".to_string();
+    let repo = Repo::with_revision(
+        dataset_id,
+        RepoType::Dataset,
+        "refs/convert/parquet".to_string(),
+    );
+    let repo = api.repo(repo);
+    let test_parquet_filename = repo
+        .get("mnist/mnist-test.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let train_parquet_filename = repo
+        .get("mnist/mnist-train.parquet")
+        .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+    let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+        .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+    let (test_images, test_labels) = load_parquet(test_parquet)?;
+    let (train_images, train_labels) = load_parquet(train_parquet)?;
+    Ok(crate::vision::Dataset {
+        train_images,
+        train_labels,
+        test_images,
+        test_labels,
+        labels: 10,
+    })
+}
```
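A minimal, hypothetical sketch of calling the new entry point (the first call downloads the parquet shards over the network; `hf-hub` caches them afterwards):

```rust
fn main() -> candle::Result<()> {
    // Downloads the parquet shards on first use, then loads them into tensors.
    let m = candle_datasets::vision::mnist::load()?;
    assert_eq!(m.train_images.dims(), &[60_000, 784]);
    assert_eq!(m.labels, 10);
    Ok(())
}
```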
candle-examples/examples/mnist-training/main.rs

```diff
@@ -138,12 +138,20 @@ struct Args {
     /// The file where to load the trained weights from, in safetensors format.
     #[arg(long)]
     load: Option<String>,
+
+    /// The directory to load the local MNIST dataset from.
+    #[arg(long)]
+    local_mnist: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     // Load the dataset
-    let m = candle_datasets::vision::mnist::load_dir("data")?;
+    let m = if let Some(directory) = args.local_mnist {
+        candle_datasets::vision::mnist::load_dir(directory)?
+    } else {
+        candle_datasets::vision::mnist::load()?
+    };
     println!("train-images: {:?}", m.train_images.shape());
     println!("train-labels: {:?}", m.train_labels.shape());
     println!("test-images: {:?}", m.test_images.shape());
```
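With this change the example works out of the box: running it with no flags pulls MNIST from the hub via `candle_datasets::vision::mnist::load()`, while passing `--local-mnist <dir>` (the clap flag derived from the `local_mnist` field) keeps the previous behavior of reading a local data directory.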
candle-examples/src/lib.rs

```diff
@@ -52,102 +52,3 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
-
-#[cfg(test)]
-mod tests {
-    // ... the book_hub_1 / book_hub_2 / book_hub_3 tests, elided; see below ...
-}
```

The removed `mod tests` body (the `book_hub_1`, `book_hub_2`, and `book_hub_3` tests, including the mdBook `ANCHOR` comments) is byte-for-byte the code now living in `candle-book/src/lib.rs` above; the tests move to the book crate rather than being rewritten, so the duplicate listing is elided here.