diff --git a/candle-datasets/src/vision/mnist.rs b/candle-datasets/src/vision/mnist.rs
index c908412c..1085edd6 100644
--- a/candle-datasets/src/vision/mnist.rs
+++ b/candle-datasets/src/vision/mnist.rs
@@ -2,7 +2,7 @@
//!
//! The files can be obtained from the following link:
//!
-use candle::{DType, Device, Result, Tensor};
+use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
@@ -92,7 +92,7 @@ fn load_parquet(parquet: SerializedFileReader) -> Result<(Tensor,
}
pub fn load() -> Result {
- let api = Api::new().unwrap();
+ let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
@@ -100,12 +100,16 @@ pub fn load() -> Result {
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
- let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
- let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
- let test_parquet =
- SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
- let train_parquet =
- SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
+ let test_parquet_filename = repo
+ .get("mnist/mnist-test.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let train_parquet_filename = repo
+ .get("mnist/mnist-train.parquet")
+ .map_err(|e| Error::Msg(format!("Api error: {e}")))?;
+ let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
+ let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
+ .map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index 2cdad9fa..ea4b4bce 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -138,12 +138,20 @@ struct Args {
/// The file where to load the trained weights from, in safetensors format.
#[arg(long)]
load: Option,
+
+ /// The file where to load the trained weights from, in safetensors format.
+ #[arg(long)]
+ local_mnist: Option,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_datasets::vision::mnist::load()?;
+ let m = if let Some(directory) = args.local_mnist {
+ candle_datasets::vision::mnist::load_dir(directory)?
+ } else {
+ candle_datasets::vision::mnist::load()?
+ };
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());