Re-enable local dir for mnist.

This commit is contained in:
Nicolas Patry
2023-08-15 15:52:37 +02:00
parent dd06d93d0b
commit d726484a6d
2 changed files with 21 additions and 9 deletions

View File

@@ -2,7 +2,7 @@
//!
//! The files can be obtained from the following link:
//! <http://yann.lecun.com/exdb/mnist/>
use candle::{DType, Device, Result, Tensor};
use candle::{DType, Device, Error, Result, Tensor};
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
@@ -92,7 +92,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
}
pub fn load() -> Result<crate::vision::Dataset> {
let api = Api::new().unwrap();
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let dataset_id = "mnist".to_string();
let repo = Repo::with_revision(
dataset_id,
@@ -100,12 +100,16 @@ pub fn load() -> Result<crate::vision::Dataset> {
"refs/convert/parquet".to_string(),
);
let repo = api.repo(repo);
let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
let test_parquet =
SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
let train_parquet =
SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
let test_parquet_filename = repo
.get("mnist/mnist-test.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let train_parquet_filename = repo
.get("mnist/mnist-train.parquet")
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
let (test_images, test_labels) = load_parquet(test_parquet)?;
let (train_images, train_labels) = load_parquet(train_parquet)?;
Ok(crate::vision::Dataset {