mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Re-enable local dir for mnist.
This commit is contained in:
@ -2,7 +2,7 @@
|
|||||||
//!
|
//!
|
||||||
//! The files can be obtained from the following link:
|
//! The files can be obtained from the following link:
|
||||||
//! <http://yann.lecun.com/exdb/mnist/>
|
//! <http://yann.lecun.com/exdb/mnist/>
|
||||||
use candle::{DType, Device, Result, Tensor};
|
use candle::{DType, Device, Error, Result, Tensor};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use parquet::file::reader::{FileReader, SerializedFileReader};
|
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
@ -92,7 +92,7 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn load() -> Result<crate::vision::Dataset> {
|
pub fn load() -> Result<crate::vision::Dataset> {
|
||||||
let api = Api::new().unwrap();
|
let api = Api::new().map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||||
let dataset_id = "mnist".to_string();
|
let dataset_id = "mnist".to_string();
|
||||||
let repo = Repo::with_revision(
|
let repo = Repo::with_revision(
|
||||||
dataset_id,
|
dataset_id,
|
||||||
@ -100,12 +100,16 @@ pub fn load() -> Result<crate::vision::Dataset> {
|
|||||||
"refs/convert/parquet".to_string(),
|
"refs/convert/parquet".to_string(),
|
||||||
);
|
);
|
||||||
let repo = api.repo(repo);
|
let repo = api.repo(repo);
|
||||||
let test_parquet_filename = repo.get("mnist/mnist-test.parquet").unwrap();
|
let test_parquet_filename = repo
|
||||||
let train_parquet_filename = repo.get("mnist/mnist-train.parquet").unwrap();
|
.get("mnist/mnist-test.parquet")
|
||||||
let test_parquet =
|
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||||
SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?).unwrap();
|
let train_parquet_filename = repo
|
||||||
let train_parquet =
|
.get("mnist/mnist-train.parquet")
|
||||||
SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?).unwrap();
|
.map_err(|e| Error::Msg(format!("Api error: {e}")))?;
|
||||||
|
let test_parquet = SerializedFileReader::new(std::fs::File::open(test_parquet_filename)?)
|
||||||
|
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||||
|
let train_parquet = SerializedFileReader::new(std::fs::File::open(train_parquet_filename)?)
|
||||||
|
.map_err(|e| Error::Msg(format!("Parquet error: {e}")))?;
|
||||||
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
let (test_images, test_labels) = load_parquet(test_parquet)?;
|
||||||
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
let (train_images, train_labels) = load_parquet(train_parquet)?;
|
||||||
Ok(crate::vision::Dataset {
|
Ok(crate::vision::Dataset {
|
||||||
|
@ -138,12 +138,20 @@ struct Args {
|
|||||||
/// The file where to load the trained weights from, in safetensors format.
|
/// The file where to load the trained weights from, in safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
load: Option<String>,
|
load: Option<String>,
|
||||||
|
|
||||||
|
/// The file where to load the trained weights from, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
local_mnist: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
// Load the dataset
|
// Load the dataset
|
||||||
let m = candle_datasets::vision::mnist::load()?;
|
let m = if let Some(directory) = args.local_mnist {
|
||||||
|
candle_datasets::vision::mnist::load_dir(directory)?
|
||||||
|
} else {
|
||||||
|
candle_datasets::vision::mnist::load()?
|
||||||
|
};
|
||||||
println!("train-images: {:?}", m.train_images.shape());
|
println!("train-images: {:?}", m.train_images.shape());
|
||||||
println!("train-labels: {:?}", m.train_labels.shape());
|
println!("train-labels: {:?}", m.train_labels.shape());
|
||||||
println!("test-images: {:?}", m.test_images.shape());
|
println!("test-images: {:?}", m.test_images.shape());
|
||||||
|
Reference in New Issue
Block a user