//! The MNIST hand-written digit dataset. //! //! The files can be obtained from the following link: //! use candle::{DType, Device, Result, Tensor}; use std::fs::File; use std::io::{self, BufReader, Read}; fn read_u32(reader: &mut T) -> Result { let mut b = vec![0u8; 4]; reader.read_exact(&mut b)?; let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| { (s + basis * u64::from(x), basis * 256) }); Ok(result as u32) } fn check_magic_number(reader: &mut T, expected: u32) -> Result<()> { let magic_number = read_u32(reader)?; if magic_number != expected { Err(io::Error::new( io::ErrorKind::Other, format!("incorrect magic number {magic_number} != {expected}"), ))?; } Ok(()) } fn read_labels(filename: &std::path::Path) -> Result { let mut buf_reader = BufReader::new(File::open(filename)?); check_magic_number(&mut buf_reader, 2049)?; let samples = read_u32(&mut buf_reader)?; let mut data = vec![0u8; samples as usize]; buf_reader.read_exact(&mut data)?; let samples = data.len(); Tensor::from_vec(data, samples, &Device::Cpu) } fn read_images(filename: &std::path::Path) -> Result { let mut buf_reader = BufReader::new(File::open(filename)?); check_magic_number(&mut buf_reader, 2051)?; let samples = read_u32(&mut buf_reader)? as usize; let rows = read_u32(&mut buf_reader)? as usize; let cols = read_u32(&mut buf_reader)? as usize; let data_len = samples * rows * cols; let mut data = vec![0u8; data_len]; buf_reader.read_exact(&mut data)?; let tensor = Tensor::from_vec(data, (samples, rows * cols), &Device::Cpu)?; tensor.to_dtype(DType::F32)? / 255. } pub fn load_dir>(dir: T) -> Result { let dir = dir.as_ref(); let train_images = read_images(&dir.join("train-images-idx3-ubyte"))?; let train_labels = read_labels(&dir.join("train-labels-idx1-ubyte"))?; let test_images = read_images(&dir.join("t10k-images-idx3-ubyte"))?; let test_labels = read_labels(&dir.join("t10k-labels-idx1-ubyte"))?; Ok(crate::vision::Dataset { train_images, train_labels, test_images, test_labels, labels: 10, }) }