mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Cleanup:
- Moved around book from `examples` to `candle-book` proper (overlapping the book and the lib structures)
This commit is contained in:
@ -3,6 +3,7 @@ members = [
|
||||
"candle-core",
|
||||
"candle-datasets",
|
||||
"candle-examples",
|
||||
"candle-book",
|
||||
"candle-nn",
|
||||
"candle-pyo3",
|
||||
"candle-transformers",
|
||||
@ -57,6 +58,7 @@ tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
parquet = { version = "45.0.0" }
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
49
candle-book/Cargo.toml
Normal file
49
candle-book/Cargo.toml
Normal file
@ -0,0 +1,49 @@
|
||||
[package]
|
||||
name = "candle-book"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.1.0" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
byteorder = { workspace = true }
|
||||
hf-hub = { workspace = true, features=["tokio"]}
|
||||
clap = { workspace = true }
|
||||
memmap2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
tokio = "1.29.1"
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
@ -39,7 +39,7 @@ cargo add hf-hub --features tokio
|
||||
```rust,ignore
|
||||
# This is tested directly in examples crate because it needs external dependencies unfortunately:
|
||||
# See [this](https://github.com/rust-lang/mdBook/issues/706)
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
|
||||
{{#include ../lib.rs:book_hub_1}}
|
||||
```
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ For more efficient loading, instead of reading the file, you could use [`memmap2
|
||||
and will definitely be slower on network mounted disk, because it will issue more read calls.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
|
||||
{{#include ../lib.rs:book_hub_2}}
|
||||
```
|
||||
|
||||
**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
|
||||
@ -100,5 +100,5 @@ cargo add safetensors
|
||||
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
|
||||
{{#include ../lib.rs:book_hub_3}}
|
||||
```
|
||||
|
193
candle-book/src/lib.rs
Normal file
193
candle-book/src/lib.rs
Normal file
@ -0,0 +1,193 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
async fn book_hub_1() {
|
||||
// ANCHOR: book_hub_1
|
||||
use candle::Device;
|
||||
use hf_hub::api::tokio::Api;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights_filename = repo.get("model.safetensors").await.unwrap();
|
||||
|
||||
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_1
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_2
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_3() {
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
|
||||
// Use safetensors directly
|
||||
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
|
||||
let view = tensors
|
||||
.tensor("bert.encoder.layer.0.attention.self.query.weight")
|
||||
.unwrap();
|
||||
|
||||
// We're going to load shard with rank 1, within a world_size of 4
|
||||
// We're going to split along dimension 0 doing VIEW[start..stop, :]
|
||||
let rank = 1;
|
||||
let world_size = 4;
|
||||
let dim = 0;
|
||||
let dtype = view.dtype();
|
||||
let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
// Everything is expressed in tensor dimension
|
||||
// bytes offsets is handled automatically for safetensors.
|
||||
|
||||
let iterator = view.slice(start..stop).unwrap();
|
||||
|
||||
tp_shape[dim] = block_size;
|
||||
|
||||
// Convert safetensors Dtype to candle DType
|
||||
let dtype: DType = dtype.try_into().unwrap();
|
||||
|
||||
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_3
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
let dataset_id = "mnist".to_string();
|
||||
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// Ignore unused
|
||||
let _train = train_parquet;
|
||||
// ANCHOR: book_training_2
|
||||
for row in test_parquet {
|
||||
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR: book_training_3
|
||||
|
||||
let test_samples = 10_000;
|
||||
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
|
||||
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
|
||||
for row in test_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
test_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
test_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
|
||||
|
||||
let train_samples = 60_000;
|
||||
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
|
||||
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
|
||||
for row in train_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
train_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
train_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
|
||||
|
||||
let mnist = candle_datasets::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
};
|
||||
|
||||
// ANCHOR_END: book_training_3
|
||||
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
|
||||
assert_eq!(mnist.test_labels.dims(), &[10_000]);
|
||||
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
|
||||
assert_eq!(mnist.train_labels.dims(), &[60_000]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -3,7 +3,7 @@
|
||||
So we now have downloaded the MNIST parquet files, let's put them in a simple struct.
|
||||
|
||||
```rust,ignore
|
||||
{{#include ../../../candle-examples/src/lib.rs:book_training_3}}
|
||||
{{#include ../lib.rs:book_training_3}}
|
||||
```
|
||||
|
||||
The parsing of the file and putting it into single tensors requires the dataset to fit the entire memory.
|
||||
|
@ -9,6 +9,9 @@ categories.workspace = true
|
||||
license.workspace = true
|
||||
readme = "README.md"
|
||||
|
||||
[lib]
|
||||
crate-type=["dylib"]
|
||||
|
||||
[dependencies]
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.2.0", package = "candle-core" }
|
||||
|
@ -81,14 +81,14 @@ mod tests {
|
||||
);
|
||||
|
||||
// Keep for now to showcase how to use.
|
||||
//for row in rows {
|
||||
// if let Ok(row) = row {
|
||||
// for (_idx, (_name, field)) in row.get_column_iter().enumerate() {
|
||||
// if let parquet::record::Field::Str(value) = field {
|
||||
// println!("Value {value:?}");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
for (_idx, (_name, field)) in row.get_column_iter().enumerate() {
|
||||
if let parquet::record::Field::Str(value) = field {
|
||||
println!("Value {value:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,234 +15,3 @@ pub fn device(cpu: bool) -> Result<Device> {
|
||||
Ok(device)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
width: usize,
|
||||
height: usize,
|
||||
) -> Result<Tensor> {
|
||||
let img = image::io::Reader::open(p)?
|
||||
.decode()
|
||||
.map_err(candle::Error::wrap)?
|
||||
.resize_to_fill(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Triangle,
|
||||
);
|
||||
let img = img.to_rgb8();
|
||||
let data = img.into_raw();
|
||||
Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
|
||||
}
|
||||
|
||||
/// Saves an image to disk using the image crate, this expects an input with shape
|
||||
/// (c, width, height).
|
||||
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
let p = p.as_ref();
|
||||
let (channel, width, height) = img.dims3()?;
|
||||
if channel != 3 {
|
||||
candle::bail!("save_image expects an input of shape (3, width, height)")
|
||||
}
|
||||
let img = img.transpose(0, 1)?.t()?.flatten_all()?;
|
||||
let pixels = img.to_vec1::<u8>()?;
|
||||
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||
Some(image) => image,
|
||||
None => candle::bail!("error saving image {p:?}"),
|
||||
};
|
||||
image.save(p).map_err(candle::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use parquet::file::reader::SerializedFileReader;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
async fn book_hub_1() {
|
||||
// ANCHOR: book_hub_1
|
||||
use candle::Device;
|
||||
use hf_hub::api::tokio::Api;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
|
||||
let weights_filename = repo.get("model.safetensors").await.unwrap();
|
||||
|
||||
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_1
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_2() {
|
||||
// ANCHOR: book_hub_2
|
||||
use candle::Device;
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_2
|
||||
assert_eq!(weights.len(), 206);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_hub_3() {
|
||||
// ANCHOR: book_hub_3
|
||||
use candle::{DType, Device, Tensor};
|
||||
use hf_hub::api::sync::Api;
|
||||
use memmap2::Mmap;
|
||||
use safetensors::slice::IndexOp;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs;
|
||||
|
||||
let api = Api::new().unwrap();
|
||||
let repo = api.model("bert-base-uncased".to_string());
|
||||
let weights_filename = repo.get("model.safetensors").unwrap();
|
||||
|
||||
let file = fs::File::open(weights_filename).unwrap();
|
||||
let mmap = unsafe { Mmap::map(&file).unwrap() };
|
||||
|
||||
// Use safetensors directly
|
||||
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
|
||||
let view = tensors
|
||||
.tensor("bert.encoder.layer.0.attention.self.query.weight")
|
||||
.unwrap();
|
||||
|
||||
// We're going to load shard with rank 1, within a world_size of 4
|
||||
// We're going to split along dimension 0 doing VIEW[start..stop, :]
|
||||
let rank = 1;
|
||||
let world_size = 4;
|
||||
let dim = 0;
|
||||
let dtype = view.dtype();
|
||||
let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
// Everything is expressed in tensor dimension
|
||||
// bytes offsets is handled automatically for safetensors.
|
||||
|
||||
let iterator = view.slice(start..stop).unwrap();
|
||||
|
||||
tp_shape[dim] = block_size;
|
||||
|
||||
// Convert safetensors Dtype to candle DType
|
||||
let dtype: DType = dtype.try_into().unwrap();
|
||||
|
||||
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
|
||||
// ANCHOR_END: book_hub_3
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
let dataset_id = "mnist".to_string();
|
||||
|
||||
let api = Api::new()?;
|
||||
let repo = Repo::with_revision(
|
||||
dataset_id,
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let repo = api.repo(repo);
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// Ignore unused
|
||||
let _train = train_parquet;
|
||||
// ANCHOR: book_training_2
|
||||
for row in test_parquet {
|
||||
for (idx, (name, field)) in row?.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet")?;
|
||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet")?;
|
||||
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)?;
|
||||
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)?;
|
||||
// ANCHOR: book_training_3
|
||||
|
||||
let test_samples = 10_000;
|
||||
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
|
||||
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
|
||||
for row in test_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
test_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
test_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
|
||||
|
||||
let train_samples = 60_000;
|
||||
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
|
||||
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
|
||||
for row in train_parquet{
|
||||
for (_name, field) in row?.get_column_iter() {
|
||||
if let parquet::record::Field::Group(subrow) = field {
|
||||
for (_name, field) in subrow.get_column_iter() {
|
||||
if let parquet::record::Field::Bytes(value) = field {
|
||||
let image = image::load_from_memory(value.data()).unwrap();
|
||||
train_buffer_images.extend(image.to_luma8().as_raw());
|
||||
}
|
||||
}
|
||||
}else if let parquet::record::Field::Long(label) = field {
|
||||
train_buffer_labels.push(*label as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
|
||||
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
|
||||
|
||||
let mnist = candle_datasets::vision::Dataset {
|
||||
train_images,
|
||||
train_labels,
|
||||
test_images,
|
||||
test_labels,
|
||||
labels: 10,
|
||||
};
|
||||
|
||||
// ANCHOR_END: book_training_3
|
||||
assert_eq!(mnist.test_images.dims(), &[10_000, 784]);
|
||||
assert_eq!(mnist.test_labels.dims(), &[10_000]);
|
||||
assert_eq!(mnist.train_images.dims(), &[60_000, 784]);
|
||||
assert_eq!(mnist.train_labels.dims(), &[60_000]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user