diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs index c908412c..1085edd6 100644 --- a/candle-datasets/src/vision/mnist.rs +++ b/candle-datasets/src/vision/mnist.rs @@ -2,7 +2,7 @@ //! //! The files can be obtained from the following link: //! -use candle::{DType, Device, Result, Tensor}; +use candle::{DType, Device, Error, Result, Tensor}; use hf_hub::{api::sync::Api, Repo, RepoType}; use parquet::file::reader::{FileReader, SerializedFileReader}; use std::fs::File; @@ -92,7 +92,7 @@ fn load_parquet(parquet: SerializedFileReader) -> Result<(Tensor, } pub fn load() -> Result { - let api = Api::new().unwrap(); + let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?; let dataset_id = "mnist".to_string(); let repo = Repo::with_revision( dataset_id, @@ -100,12 +100,16 @@ pub fn load() -> Result { "refs/convert/parquet".to_string(), ); let repo = api.repo(repo); - let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap(); - let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap(); - let test_parquet = - SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap(); - let train_parquet = - SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap(); + let test_parquet_filename = repo + .get("mnist/mnist-test.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let train_parquet_filename = repo + .get("mnist/mnist-train.parquet") + .map_err(|e| Error::Msg(format!("Api error: {e}")))?; + let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; + let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?) + .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?; let (test_images, test_labels) = load_parquet(test_parquet)?; let (train_images, train_labels) = load_parquet(train_parquet)?; Ok(crate::vision::Dataset { diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index 2cdad9fa..ea4b4bce 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -138,12 +138,20 @@ struct Args { /// The file where to load the trained weights from, in safetensors format. #[arg(long)] load: Option, + + /// The file where to load the trained weights from, in safetensors format. + #[arg(long)] + local_mnist: Option, } pub fn main() -> anyhow::Result<()> { let args = Args::parse(); // Load the dataset - let m = candle_datasets::vision::mnist::load()?; + let m = if let Some(directory) = args.local_mnist { + candle_datasets::vision::mnist::load_dir(directory)? + } else { + candle_datasets::vision::mnist::load()? + }; println!("train-images: {:?}", m.train_images.shape()); println!("train-labels: {:?}", m.train_labels.shape()); println!("test-images: {:?}", m.test_images.shape());