mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Better training+hub
This commit is contained in:
@ -55,6 +55,8 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
|
||||
#[rustfmt::skip]
|
||||
#[tokio::test]
|
||||
@ -150,4 +152,25 @@ let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).un
|
||||
assert_eq!(view.shape(), &[768, 768]);
|
||||
assert_eq!(tp_tensor.dims(), &[192, 768]);
|
||||
}
|
||||
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn book_training_1() -> Result<()>{
|
||||
// ANCHOR: book_training_1
|
||||
use candle_datasets::hub::from_hub;
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
let api = Api::new()?;
|
||||
let files = from_hub(&api, "mnist".to_string())?;
|
||||
// ANCHOR_END: book_training_1
|
||||
// ANCHOR: book_training_2
|
||||
let rows = files.into_iter().flat_map(|r| r.into_iter()).flatten();
|
||||
for row in rows {
|
||||
for (idx, (name, field)) in row.get_column_iter().enumerate() {
|
||||
println!("Column id {idx}, name {name}, value {field}");
|
||||
}
|
||||
}
|
||||
// ANCHOR_END: book_training_2
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user