// This should reach 91.5% accuracy.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use clap::{Parser, ValueEnum};

use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;

struct TensorData {
    tensors: std::collections::HashMap<String, Var>,
    pub dtype: DType,
    pub device: Device,
}

// A variant of candle_nn::VarBuilder for initializing variables before training.
#[derive(Clone)]
struct VarStore {
    data: Arc<Mutex<TensorData>>,
    path: Vec<String>,
}

impl VarStore {
    fn new(dtype: DType, device: Device) -> Self {
        let data = TensorData {
            tensors: std::collections::HashMap::new(),
            dtype,
            device,
        };
        Self {
            data: Arc::new(Mutex::new(data)),
            path: vec![],
        }
    }

    // Returns a new VarStore sharing the same storage but scoped under the
    // given sub-path, so nested layers get dotted names like "ln1.weight".
    fn pp(&self, s: &str) -> Self {
        let mut path = self.path.clone();
        path.push(s.to_string());
        Self {
            data: self.data.clone(),
            path,
        }
    }

    // Fetches an existing variable by name, or creates and registers it using
    // the provided initializer if it does not exist yet.
    fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
        let shape = shape.into();
        let path = if self.path.is_empty() {
            tensor_name.to_string()
        } else {
            [&self.path.join("."), tensor_name].join(".")
        };
        let mut tensor_data = self.data.lock().unwrap();
        if let Some(tensor) = tensor_data.tensors.get(&path) {
            let tensor_shape = tensor.shape();
            if &shape != tensor_shape {
                candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
            }
            return Ok(tensor.as_tensor().clone());
        }
        let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
        let tensor = var.as_tensor().clone();
        tensor_data.tensors.insert(path, var);
        Ok(tensor)
    }

    fn all_vars(&self) -> Vec<Var> {
        let tensor_data = self.data.lock().unwrap();
        #[allow(clippy::map_clone)]
        tensor_data
            .tensors
            .values()
            .map(|c| c.clone())
            .collect::<Vec<_>>()
    }

    // Serializes all registered variables to a safetensors file.
    fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
        let tensor_data = self.data.lock().unwrap();
        let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
        Ok(())
    }

    // Loads values for every registered variable from a safetensors file,
    // failing if any variable is missing from the file.
    fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
        use candle::safetensors::Load;
        let path = path.as_ref();
        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
        let data = data.deserialize()?;
        let mut tensor_data = self.data.lock().unwrap();
        for (name, var) in tensor_data.tensors.iter_mut() {
            match data.tensor(name) {
                Ok(data) => {
                    let data: Tensor = data.load(var.device())?;
                    if let Err(err) = var.set(&data) {
                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
                    }
                }
                Err(_) => candle::bail!("cannot find tensor for {name}"),
            }
        }
        Ok(())
    }
}

// Linear layer with weights and biases initialized to zero.
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
    let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
    let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
    Ok(Linear::new(ws, Some(bs)))
}

// Linear layer with Kaiming-normal weight init and uniform bias init.
fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
    let bound = 1. / (in_dim as f64).sqrt();
    let init_bs = Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vs.get(out_dim, "bias", init_bs)?;
    Ok(Linear::new(ws, Some(bs)))
}

trait Model: Sized {
    fn new(vs: VarStore) -> Result<Self>;
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

struct LinearModel {
    linear: Linear,
}

impl Model for LinearModel {
    fn new(vs: VarStore) -> Result<Self> {
        let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
        Ok(Self { linear })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear.forward(xs)
    }
}

struct Mlp {
    ln1: Linear,
    ln2: Linear,
}

impl Model for Mlp {
    fn new(vs: VarStore) -> Result<Self> {
        let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
        let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
        Ok(Self { ln1, ln2 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        self.ln2.forward(&xs)
    }
}

struct TrainingArgs {
    learning_rate: f64,
    load: Option<String>,
    save: Option<String>,
    epochs: usize,
}

fn training_loop<M: Model>(
    m: candle_nn::vision::Dataset,
    args: &TrainingArgs,
) -> anyhow::Result<()> {
    let dev = candle::Device::cuda_if_available(0)?;

    let train_labels = m.train_labels;
    let train_images = m.train_images.to_device(&dev)?;
    let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;

    let mut vs = VarStore::new(DType::F32, dev.clone());
    let model = M::new(vs.clone())?;

    if let Some(load) = &args.load {
        println!("loading weights from {load}");
        vs.load(load)?
    }

    let all_vars = vs.all_vars();
    let all_vars = all_vars.iter().collect::<Vec<_>>();
    let sgd = candle_nn::SGD::new(&all_vars, args.learning_rate);
    let test_images = m.test_images.to_device(&dev)?;
    let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
    // Full-batch gradient descent: one forward/backward pass over the whole
    // training set per epoch, followed by an accuracy check on the test set.
    for epoch in 1..args.epochs {
        let logits = model.forward(&train_images)?;
        let log_sm = ops::log_softmax(&logits, D::Minus1)?;
        let loss = loss::nll(&log_sm, &train_labels)?;
        sgd.backward_step(&loss)?;

        let test_logits = model.forward(&test_images)?;
        let sum_ok = test_logits
            .argmax(D::Minus1)?
            .eq(&test_labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        let test_accuracy = sum_ok / test_labels.dims1()? as f32;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            100. * test_accuracy
        );
    }
    if let Some(save) = &args.save {
        println!("saving trained weights in {save}");
        vs.save(save)?
    }
    Ok(())
}

#[derive(ValueEnum, Clone)]
enum WhichModel {
    Linear,
    Mlp,
}

#[derive(Parser)]
struct Args {
    #[clap(value_enum, default_value_t = WhichModel::Linear)]
    model: WhichModel,

    #[arg(long)]
    learning_rate: Option<f64>,

    #[arg(long, default_value_t = 200)]
    epochs: usize,

    /// The file to save the trained weights to, in safetensors format.
    #[arg(long)]
    save: Option<String>,

    /// The file to load the trained weights from, in safetensors format.
    #[arg(long)]
    load: Option<String>,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Load the dataset.
    let m = candle_nn::vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.shape());
    println!("train-labels: {:?}", m.train_labels.shape());
    println!("test-images: {:?}", m.test_images.shape());
    println!("test-labels: {:?}", m.test_labels.shape());
    let default_learning_rate = match args.model {
        WhichModel::Linear => 1.,
        WhichModel::Mlp => 0.05,
    };
    let training_args = TrainingArgs {
        epochs: args.epochs,
        learning_rate: args.learning_rate.unwrap_or(default_learning_rate),
        load: args.load,
        save: args.save,
    };
    match args.model {
        WhichModel::Linear => training_loop::<LinearModel>(m, &training_args),
        WhichModel::Mlp => training_loop::<Mlp>(m, &training_args),
    }
}
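
// Example invocations (a sketch, not part of the original file: it assumes
// this file is built as a cargo example/binary and that the raw MNIST files
// have already been downloaded into a local `data/` directory, as expected by
// `load_dir("data")` above). The flag names follow from the `Args` struct:
//
//   cargo run --release -- linear
//   cargo run --release -- mlp --learning-rate 0.05 --epochs 100 --save mlp.safetensors
//   cargo run --release -- mlp --load mlp.safetensors --epochs 50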