mirror of https://github.com/huggingface/candle.git
Better training+hub
@@ -6,12 +6,36 @@ start with the Hello world dataset of machine learning, MNIST.
 
 Let's start with downloading `MNIST` from [huggingface](https://huggingface.co/datasets/mnist).
 
-```rust
-use candle_datasets::from_hub;
-
-let dataset = from_hub("mnist")?;
-```
+This requires `candle-datasets` with the `hub` feature.
+
+```bash
+cargo add candle-datasets --features hub
+cargo add hf-hub
+```
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_1}}
+```
 
 This uses the standardized `parquet` files from the `refs/convert/parquet` branch on every dataset.
 
+`files` is now a `Vec` of [`parquet::file::serialized_reader::SerializedFileReader`].
+
+We can inspect the content of the files with:
+
+```rust,ignore
+{{#include ../../../candle-examples/src/lib.rs:book_training_2}}
+```
+
+You should see something like:
+
+```bash
+Column id 1, name label, value 6
+Column id 0, name image, value {bytes: [137, ....]
+Column id 1, name label, value 8
+Column id 0, name image, value {bytes: [137, ....]
+```
+
+So each row contains 2 columns (image, label) with image being saved as bytes.
+
 Let's put them into a useful struct.
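The book text stops at "Let's put them into a useful struct." As a minimal sketch of what that could look like (not part of this commit): `MnistDataset` and `collect_rows` are hypothetical names, and it assumes the `image` crate for decoding the PNG bytes, that labels are stored as int64 in the converted parquet schema, and candle's `Tensor::from_vec` constructor.

```rust
use std::fs::File;

use candle_core::{Device, Tensor};
use parquet::file::serialized_reader::SerializedFileReader;
use parquet::record::Field;

/// Hypothetical container for the decoded rows.
struct MnistDataset {
    images: Tensor, // (n, 784) u8 pixel values
    labels: Tensor, // (n,) u8 class ids
}

fn collect_rows(files: Vec<SerializedFileReader<File>>) -> anyhow::Result<MnistDataset> {
    let mut pixels: Vec<u8> = Vec::new();
    let mut labels: Vec<u8> = Vec::new();
    // Same traversal as the commit's `book_training_2` snippet.
    for row in files.into_iter().flat_map(|r| r.into_iter()).flatten() {
        for (name, field) in row.get_column_iter() {
            match (name.as_str(), field) {
                // The image column is a group wrapping PNG-encoded bytes
                // (137 is the PNG magic byte seen in the output above).
                ("image", Field::Group(inner)) => {
                    for (_, f) in inner.get_column_iter() {
                        if let Field::Bytes(b) = f {
                            let img = image::load_from_memory(b.data())?;
                            pixels.extend(img.to_luma8().into_raw());
                        }
                    }
                }
                // Assumes the label is an int64; adjust if the schema differs.
                ("label", Field::Long(l)) => labels.push(*l as u8),
                _ => {}
            }
        }
    }
    let n = labels.len();
    Ok(MnistDataset {
        images: Tensor::from_vec(pixels, (n, 28 * 28), &Device::Cpu)?,
        labels: Tensor::from_vec(labels, n, &Device::Cpu)?,
    })
}
```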
@@ -22,7 +22,7 @@ fn sibling_to_parquet(
     repo: &ApiRepo,
 ) -> Result<SerializedFileReader<File>, Error> {
     let local = repo.get(rfilename)?;
-    let file = File::open(&local)?;
+    let file = File::open(local)?;
     let reader = SerializedFileReader::new(file)?;
     Ok(reader)
 }
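The one-line change here drops a needless borrow: `repo.get` returns an owned `PathBuf` and `File::open` takes `impl AsRef<Path>`, so `local` can be passed by value (the pattern clippy's `needless_borrow` lint flags). For context, a hypothetical standalone call to this helper, assuming hf-hub's sync API and that the helper's `Error` converts into `anyhow::Error`; the parquet filename is illustrative.

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::FileReader;

fn main() -> anyhow::Result<()> {
    let api = Api::new()?;
    // Datasets expose their auto-converted parquet files on this branch.
    let repo = api.repo(Repo::with_revision(
        "mnist".to_string(),
        RepoType::Dataset,
        "refs/convert/parquet".to_string(),
    ));
    // Illustrative filename; list the repo's siblings for the real ones.
    let reader = sibling_to_parquet("mnist/train/0000.parquet", &repo)?;
    println!("rows: {}", reader.metadata().file_metadata().num_rows());
    Ok(())
}
```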
@@ -55,6 +55,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
 
 #[cfg(test)]
 mod tests {
+    use anyhow::Result;
+
     // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
     #[rustfmt::skip]
     #[tokio::test]
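The imported `anyhow::Result` is what lets the test added further down return `Result<()>` and use `?` across the distinct error types raised by hf-hub and parquet. A minimal illustration of that pattern (a made-up test, not from this commit):

```rust
use anyhow::Result;

#[test]
fn mixed_error_types() -> Result<()> {
    // `?` converts both std::io::Error and std::num::ParseIntError
    // into anyhow::Error, so one return type covers both.
    let _contents = std::fs::read_to_string("Cargo.toml")?;
    let _n: u32 = "42".parse()?;
    Ok(())
}
```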
@@ -150,4 +152,25 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
 assert_eq!(view.shape(), &[768, 768]);
 assert_eq!(tp_tensor.dims(), &[192, 768]);
     }
+
+    #[rustfmt::skip]
+    #[test]
+    fn book_training_1() -> Result<()> {
+// ANCHOR: book_training_1
+use candle_datasets::hub::from_hub;
+use hf_hub::api::sync::Api;
+
+let api = Api::new()?;
+let files = from_hub(&api, "mnist".to_string())?;
+// ANCHOR_END: book_training_1
+// ANCHOR: book_training_2
+let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
+for row in rows {
+    for (idx, (name, field)) in row.get_column_iter().enumerate() {
+        println!("Column id {idx}, name {name}, value {field}");
+    }
+}
+// ANCHOR_END: book_training_2
+Ok(())
+    }
 }
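One detail worth noting in the added test: with parquet versions where `SerializedFileReader::into_iter` yields `Result<Row>`, the trailing `.flatten()` silently skips any row that fails to read. A sketch of the same traversal that propagates those errors instead (the function name is illustrative):

```rust
use std::fs::File;

use parquet::file::serialized_reader::SerializedFileReader;

fn print_rows(files: Vec<SerializedFileReader<File>>) -> anyhow::Result<()> {
    for file in files {
        for row in file.into_iter() {
            // Each item is a Result<Row, ParquetError>; `?` surfaces failures
            // instead of dropping them the way `.flatten()` does.
            let row = row?;
            for (idx, (name, field)) in row.get_column_iter().enumerate() {
                println!("Column id {idx}, name {name}, value {field}");
            }
        }
    }
    Ok(())
}
```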