Cleanup:
- Moved around book from `examples` to `candle-book` proper (overlapping the book and the lib structures)
@@ -3,6 +3,7 @@ members = [
     "candle-core",
     "candle-datasets",
     "candle-examples",
+    "candle-book",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
@@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 zip = { version = "0.6.6", default-features = false }
+parquet = { version = "45.0.0" }

 [profile.release-with-debug]
 inherits = "release"
candle-book/Cargo.toml (new file, 49 lines)
@@ -0,0 +1,49 @@
[package]
name = "candle-book"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
readme = "README.md"

[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

[build-dependencies]
anyhow = { workspace = true }

[features]
default = []
@@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
 ```rust,ignore
 # This is tested directly in examples crate because it needs external dependencies unfortunately:
 # See [this](https://github.com/rust-lang/mdBook/issues/706)
-{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
+{{#include ../lib.rs:book_hub_1}}
 ```

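The `{{#include ../lib.rs:book_hub_1}}` directive pulls a named region out of the `candle-book/src/lib.rs` file added below, so the book snippet stays compilable and testable. A minimal sketch of the pattern (the `book_example` anchor name here is illustrative, not one used by this commit): mdBook renders only what sits between the `ANCHOR` markers, while `cargo test` still runs the whole test.

```rust
#[cfg(test)]
mod tests {
    #[test]
    fn book_example() {
        // ANCHOR: book_example
        // Only the lines between the markers appear in the rendered page
        // when the book uses `{{#include ../lib.rs:book_example}}`.
        let total = 2 + 2;
        // ANCHOR_END: book_example
        assert_eq!(total, 4);
    }
}
```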
@@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2`
 and will definitely be slower on network mounted disk, because it will issue more read calls.

 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
+{{#include ../lib.rs:book_hub_2}}
 ```

 **Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
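For contrast with the mmap-based `book_hub_2` snippet referenced above, here is a hedged sketch of the eager-read variant that the surrounding paragraph compares against; the file path and function name are illustrative, and it assumes a `model.safetensors` file is already on disk. The whole file is copied into a `Vec<u8>` before any tensor is deserialized.

```rust
use candle::Device;

// Illustrative eager variant: unlike `Mmap::map`, this reads the full file
// into memory up front before handing the bytes to candle.
fn load_eagerly() -> anyhow::Result<()> {
    let bytes = std::fs::read("model.safetensors")?;
    let weights = candle::safetensors::load_buffer(&bytes, &Device::Cpu)?;
    println!("loaded {} tensors", weights.len());
    Ok(())
}
```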
@@ -100,5 +100,5 @@ cargo add safetensors


 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
+{{#include ../lib.rs:book_hub_3}}
 ```
candle-book/src/lib.rs (new file, 193 lines)
@@ -0,0 +1,193 @@
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use candle::{DType, Device, Tensor};
    use parquet::file::reader::SerializedFileReader;

    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
    #[rustfmt::skip]
    #[tokio::test]
    async fn book_hub_1() {
        // ANCHOR: book_hub_1
        use candle::Device;
        use hf_hub::api::tokio::Api;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());

        let weights_filename = repo.get("model.safetensors").await.unwrap();

        let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_1
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_2() {
        // ANCHOR: book_hub_2
        use candle::Device;
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };
        let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_2
        assert_eq!(weights.len(), 206);
    }

    #[rustfmt::skip]
    #[test]
    fn book_hub_3() {
        // ANCHOR: book_hub_3
        use candle::{DType, Device, Tensor};
        use hf_hub::api::sync::Api;
        use memmap2::Mmap;
        use safetensors::slice::IndexOp;
        use safetensors::SafeTensors;
        use std::fs;

        let api = Api::new().unwrap();
        let repo = api.model("bert-base-uncased".to_string());
        let weights_filename = repo.get("model.safetensors").unwrap();

        let file = fs::File::open(weights_filename).unwrap();
        let mmap = unsafe { Mmap::map(&file).unwrap() };

        // Use safetensors directly
        let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
        let view = tensors
            .tensor("bert.encoder.layer.0.attention.self.query.weight")
            .unwrap();

        // We're going to load shard with rank 1, within a world_size of 4
        // We're going to split along dimension 0 doing VIEW[start..stop, :]
        let rank = 1;
        let world_size = 4;
        let dim = 0;
        let dtype = view.dtype();
        let mut tp_shape = view.shape().to_vec();
        let size = tp_shape[0];

        if size % world_size != 0 {
            panic!("The dimension is not divisble by `world_size`");
        }
        let block_size = size / world_size;
        let start = rank * block_size;
        let stop = (rank + 1) * block_size;

        // Everything is expressed in tensor dimension
        // bytes offsets is handled automatically for safetensors.

        let iterator = view.slice(start..stop).unwrap();

        tp_shape[dim] = block_size;

        // Convert safetensors Dtype to candle DType
        let dtype: DType = dtype.try_into().unwrap();

        // TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
        let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
        // ANCHOR_END: book_hub_3
        assert_eq!(view.shape(), &[768, 768]);
        assert_eq!(tp_tensor.dims(), &[192, 768]);
    }

    #[rustfmt::skip]
    #[test]
    fn book_training_1() -> Result<()>{
        // ANCHOR: book_training_1
        use hf_hub::{api::sync::Api, Repo, RepoType};

        let dataset_id = "mnist".to_string();

        let api = Api::new()?;
        let repo = Repo::with_revision(
            dataset_id,
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let repo = api.repo(repo);
        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR_END: book_training_1
        // Ignore unused
        let _train = train_parquet;
        // ANCHOR: book_training_2
        for row in test_parquet {
            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
        // ANCHOR_END: book_training_2
        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
        // ANCHOR: book_training_3

        let test_samples = 10_000;
        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
        for row in test_parquet{
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            test_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                }else if let parquet::record::Field::Long(label) = field {
                    test_buffer_labels.push(*label as u8);
                }
            }
        }
        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;

        let train_samples = 60_000;
        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
        for row in train_parquet{
            for (_name, field) in row?.get_column_iter() {
                if let parquet::record::Field::Group(subrow) = field {
                    for (_name, field) in subrow.get_column_iter() {
                        if let parquet::record::Field::Bytes(value) = field {
                            let image = image::load_from_memory(value.data()).unwrap();
                            train_buffer_images.extend(image.to_luma8().as_raw());
                        }
                    }
                }else if let parquet::record::Field::Long(label) = field {
                    train_buffer_labels.push(*label as u8);
                }
            }
        }
        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;

        let mnist = candle_datasets::vision::Dataset {
            train_images,
            train_labels,
            test_images,
            test_labels,
            labels: 10,
        };

        // ANCHOR_END: book_training_3
        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
        assert_eq!(mnist.test_labels.dims(), &[10_000]);
        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
        assert_eq!(mnist.train_labels.dims(), &[60_000]);
        Ok(())
    }
}
@@ -3,7 +3,7 @@
 So we now have downloaded the MNIST parquet files, let's put them in a simple struct.

 ```rust,ignore
-{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
+{{#include ../lib.rs:book_training_3}}
 ```

 The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
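The last context line above points out that this approach needs the whole dataset to fit in memory as single tensors. As a hedged follow-up sketch (not part of the commit), here is one way the `Dataset` built in the `book_training_3` snippet might then be walked in mini-batches, assuming candle's `Tensor::narrow`; the batch size and function name are illustrative.

```rust
use candle::{Result, Tensor};

// Illustrative only: slice the 60_000 x 784 training tensor into fixed-size
// batches instead of feeding it to a model all at once.
fn iterate_batches(train_images: &Tensor, train_labels: &Tensor) -> Result<()> {
    let batch_size = 64;
    let n_samples = train_images.dims()[0];
    for start in (0..n_samples).step_by(batch_size) {
        let len = batch_size.min(n_samples - start);
        // `narrow(dim, start, len)` takes a contiguous slice along dimension 0.
        let images = train_images.narrow(0, start, len)?;
        let labels = train_labels.narrow(0, start, len)?;
        // A forward pass and loss computation would go here.
        let _ = (images, labels);
    }
    Ok(())
}
```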
@@ -9,6 +9,9 @@ categories.workspace = true
 license.workspace = true
 readme = "README.md"
+
+[lib]
+crate-type=["dylib"]

 [dependencies]
 byteorder = { workspace = true }
 candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
@@ -81,14 +81,14 @@ mod tests {
         );

         // Keep for now to showcase how to use.
-        //for row in rows {
-        //    if let Ok(row) = row {
-        //        for (_idx, (_name, field)) in row.get_column_iter().enumerate() {
-        //            if let parquet::record::Field::Str(value) = field {
-        //                println!("Value {value:?}");
-        //            }
-        //        }
-        //    }
-        //}
+        for row in rows {
+            if let Ok(row) = row {
+                for (_idx, (_name, field)) in row.get_column_iter().enumerate() {
+                    if let parquet::record::Field::Str(value) = field {
+                        println!("Value {value:?}");
+                    }
+                }
+            }
+        }
     }
 }
@@ -15,234 +15,3 @@ pub fn device(cpu: bool) -> Result<Device> {
         Ok(device)
     }
 }
-
-pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
-    p: P,
-    width: usize,
-    height: usize,
-) -> Result<Tensor> {
-    let img = image::io::Reader::open(p)?
-        .decode()
-        .map_err(candle::Error::wrap)?
-        .resize_to_fill(
-            width as u32,
-            height as u32,
-            image::imageops::FilterType::Triangle,
-        );
-    let img = img.to_rgb8();
-    let data = img.into_raw();
-    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
-}
-
-/// Saves an image to disk using the image crate, this expects an input with shape
-/// (c, width, height).
-pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
-    let p = p.as_ref();
-    let (channel, width, height) = img.dims3()?;
-    if channel != 3 {
-        candle::bail!("save_image expects an input of shape (3, width, height)")
-    }
-    let img = img.transpose(0, 1)?.t()?.flatten_all()?;
-    let pixels = img.to_vec1::<u8>()?;
-    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
-        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
-            Some(image) => image,
-            None => candle::bail!("error saving image {p:?}"),
-        };
-    image.save(p).map_err(candle::Error::wrap)?;
-    Ok(())
-}
-
-#[cfg(test)]
-mod tests {
-    use anyhow::Result;
-    use candle::{DType, Device, Tensor};
-    use parquet::file::reader::SerializedFileReader;
-
-    // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
-    #[rustfmt::skip]
-    #[tokio::test]
-    async fn book_hub_1() {
-        // ANCHOR: book_hub_1
-        use candle::Device;
-        use hf_hub::api::tokio::Api;
-
-        let api = Api::new().unwrap();
-        let repo = api.model("bert-base-uncased".to_string());
-
-        let weights_filename = repo.get("model.safetensors").await.unwrap();
-
-        let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
-        // ANCHOR_END: book_hub_1
-        assert_eq!(weights.len(), 206);
-    }
-
-    #[rustfmt::skip]
-    #[test]
-    fn book_hub_2() {
-        // ANCHOR: book_hub_2
-        use candle::Device;
-        use hf_hub::api::sync::Api;
-        use memmap2::Mmap;
-        use std::fs;
-
-        let api = Api::new().unwrap();
-        let repo = api.model("bert-base-uncased".to_string());
-        let weights_filename = repo.get("model.safetensors").unwrap();
-
-        let file = fs::File::open(weights_filename).unwrap();
-        let mmap = unsafe { Mmap::map(&file).unwrap() };
-        let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
-        // ANCHOR_END: book_hub_2
-        assert_eq!(weights.len(), 206);
-    }
-
-    #[rustfmt::skip]
-    #[test]
-    fn book_hub_3() {
-        // ANCHOR: book_hub_3
-        use candle::{DType, Device, Tensor};
-        use hf_hub::api::sync::Api;
-        use memmap2::Mmap;
-        use safetensors::slice::IndexOp;
-        use safetensors::SafeTensors;
-        use std::fs;
-
-        let api = Api::new().unwrap();
-        let repo = api.model("bert-base-uncased".to_string());
-        let weights_filename = repo.get("model.safetensors").unwrap();
-
-        let file = fs::File::open(weights_filename).unwrap();
-        let mmap = unsafe { Mmap::map(&file).unwrap() };
-
-        // Use safetensors directly
-        let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
-        let view = tensors
-            .tensor("bert.encoder.layer.0.attention.self.query.weight")
-            .unwrap();
-
-        // We're going to load shard with rank 1, within a world_size of 4
-        // We're going to split along dimension 0 doing VIEW[start..stop, :]
-        let rank = 1;
-        let world_size = 4;
-        let dim = 0;
-        let dtype = view.dtype();
-        let mut tp_shape = view.shape().to_vec();
-        let size = tp_shape[0];
-
-        if size % world_size != 0 {
-            panic!("The dimension is not divisble by `world_size`");
-        }
-        let block_size = size / world_size;
-        let start = rank * block_size;
-        let stop = (rank + 1) * block_size;
-
-        // Everything is expressed in tensor dimension
-        // bytes offsets is handled automatically for safetensors.
-
-        let iterator = view.slice(start..stop).unwrap();
-
-        tp_shape[dim] = block_size;
-
-        // Convert safetensors Dtype to candle DType
-        let dtype: DType = dtype.try_into().unwrap();
-
-        // TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
-        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
-        let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
-        // ANCHOR_END: book_hub_3
-        assert_eq!(view.shape(), &[768, 768]);
-        assert_eq!(tp_tensor.dims(), &[192, 768]);
-    }
-
-    #[rustfmt::skip]
-    #[test]
-    fn book_training_1() -> Result<()>{
-        // ANCHOR: book_training_1
-        use hf_hub::{api::sync::Api, Repo, RepoType};
-
-        let dataset_id = "mnist".to_string();
-
-        let api = Api::new()?;
-        let repo = Repo::with_revision(
-            dataset_id,
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let repo = api.repo(repo);
-        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
-        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
-        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
-        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
-        // ANCHOR_END: book_training_1
-        // Ignore unused
-        let _train = train_parquet;
-        // ANCHOR: book_training_2
-        for row in test_parquet {
-            for (idx, (name, field)) in row?.get_column_iter().enumerate() {
-                println!("Column id {idx}, name {name}, value {field}");
-            }
-        }
-        // ANCHOR_END: book_training_2
-        let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
-        let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
-        let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
-        let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
-        // ANCHOR: book_training_3
-
-        let test_samples = 10_000;
-        let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
-        let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
-        for row in test_parquet{
-            for (_name, field) in row?.get_column_iter() {
-                if let parquet::record::Field::Group(subrow) = field {
-                    for (_name, field) in subrow.get_column_iter() {
-                        if let parquet::record::Field::Bytes(value) = field {
-                            let image = image::load_from_memory(value.data()).unwrap();
-                            test_buffer_images.extend(image.to_luma8().as_raw());
-                        }
-                    }
-                }else if let parquet::record::Field::Long(label) = field {
-                    test_buffer_labels.push(*label as u8);
-                }
-            }
-        }
-        let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
-        let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
-
-        let train_samples = 60_000;
-        let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
-        let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
-        for row in train_parquet{
-            for (_name, field) in row?.get_column_iter() {
-                if let parquet::record::Field::Group(subrow) = field {
-                    for (_name, field) in subrow.get_column_iter() {
-                        if let parquet::record::Field::Bytes(value) = field {
-                            let image = image::load_from_memory(value.data()).unwrap();
-                            train_buffer_images.extend(image.to_luma8().as_raw());
-                        }
-                    }
-                }else if let parquet::record::Field::Long(label) = field {
-                    train_buffer_labels.push(*label as u8);
-                }
-            }
-        }
-        let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
-        let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
-
-        let mnist = candle_datasets::vision::Dataset {
-            train_images,
-            train_labels,
-            test_images,
-            test_labels,
-            labels: 10,
-        };
-
-        // ANCHOR_END: book_training_3
-        assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
-        assert_eq!(mnist.test_labels.dims(), &[10_000]);
-        assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
-        assert_eq!(mnist.train_labels.dims(), &[60_000]);
-        Ok(())
-    }
-}