From a8d8f9f20601b30124d1c5096e3ad276afc99bf8 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 30 Jul 2023 17:01:45 +0100 Subject: [PATCH] Load a trained checkpoint in the mnist example. (#280) --- .../examples/mnist-training/main.rs | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index bdf28e5d..d50cb944 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -83,6 +83,27 @@ impl VarStore { safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?; Ok(()) } + + fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> { + use candle::safetensors::Load; + + let path = path.as_ref(); + let data = unsafe { candle::safetensors::MmapedFile::new(path)? }; + let data = data.deserialize()?; + let mut tensor_data = self.data.lock().unwrap(); + for (name, var) in tensor_data.tensors.iter_mut() { + match data.tensor(name) { + Ok(data) => { + let data: Tensor = data.load(var.device())?; + if let Err(err) = var.set(&data) { + candle::bail!("error setting {name} using data from {path:?}: {err}",) + } + } + Err(_) => candle::bail!("cannot find tensor for {name}"), + } + } + Ok(()) + } } fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> { @@ -145,6 +166,7 @@ impl Model for Mlp { fn training_loop<M: Model>( m: candle_nn::vision::Dataset, learning_rate: f64, + load: Option<String>, save: Option<String>, ) -> anyhow::Result<()> { let dev = candle::Device::cuda_if_available(0)?; @@ -156,9 +178,14 @@ fn training_loop<M: Model>( .unsqueeze(1)? .to_device(&dev)?; - let vs = VarStore::new(DType::F32, dev.clone()); + let mut vs = VarStore::new(DType::F32, dev.clone()); let model = M::new(vs.clone())?; + if let Some(load) = load { + println!("loading weights from {load}"); + vs.load(&load)? 
+ }
+ let all_vars = vs.all_vars(); let all_vars = all_vars.iter().collect::<Vec<_>>(); let sgd = candle_nn::SGD::new(&all_vars, learning_rate); @@ -208,6 +235,10 @@ struct Args { /// The file where to save the trained weights, in safetensors format. #[arg(long)] save: Option<String>, + + /// The file where to load the trained weights from, in safetensors format. + #[arg(long)] + load: Option<String>, } pub fn main() -> anyhow::Result<()> { @@ -221,8 +252,10 @@ pub fn main() -> anyhow::Result<()> { match args.model { WhichModel::Linear => { - training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save) + training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save) + } + WhichModel::Mlp => { + training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save) } - WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save), } }