Merge pull request #439 from huggingface/training_hub_dataset

[Book] Add small error management + start training (with generic dataset inclusion).
This commit is contained in:
Nicolas Patry
2023-08-29 13:10:05 +02:00
committed by GitHub
14 changed files with 444 additions and 110 deletions

View File

@ -138,12 +138,20 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option<String>,
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
local_mnist: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_datasets::vision::mnist::load_dir("data")?;
let m = if let Some(directory) = args.local_mnist {
candle_datasets::vision::mnist::load_dir(directory)?
} else {
candle_datasets::vision::mnist::load()?
};
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());