// This should rearch 91.5% accuracy. #[cfg(feature = "mkl")] extern crate intel_mkl_src; use anyhow::Result; use candle::{DType, Var, D}; use candle_nn::{loss, ops}; const IMAGE_DIM: usize = 784; const LABELS: usize = 10; pub fn main() -> Result<()> { let dev = candle::Device::cuda_if_available(0)?; let m = candle_nn::vision::mnist::load_dir("data")?; println!("train-images: {:?}", m.train_images.shape()); println!("train-labels: {:?}", m.train_labels.shape()); println!("test-images: {:?}", m.test_images.shape()); println!("test-labels: {:?}", m.test_labels.shape()); let train_labels = m.train_labels; let train_images = m.train_images; let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?; let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?; let bs = Var::zeros(LABELS, DType::F32, &dev)?; let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0); let test_images = m.test_images; let test_labels = m.test_labels.to_dtype(DType::U32)?; for epoch in 1..200 { let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?; let log_sm = ops::log_softmax(&logits, D::Minus1)?; let loss = loss::nll(&log_sm, &train_labels)?; sgd.backward_step(&loss)?; let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?; let sum_ok = test_logits .argmax(D::Minus1)? .eq(&test_labels)? .to_dtype(DType::F32)? .sum_all()? .to_scalar::()?; let test_accuracy = sum_ok / test_labels.dims1()? as f32; println!( "{epoch:4} train loss: {:8.5} test acc: {:5.2}%", loss.to_scalar::()?, 100. * test_accuracy ); } Ok(()) }