Load a trained checkpoint in the mnist example. (#280)

Laurent Mazare
2023-07-30 17:01:45 +01:00
committed by GitHub
parent 38ff693af0
commit a8d8f9f206

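This change teaches the mnist example to resume from a checkpoint: the example's VarStore gains a load method, and the CLI gains a --load flag mirroring the existing --save flag, so a run can start from weights written by a previous run instead of from a random initialization.

A rough usage sketch (only --save and --load are visible in this diff; the example name and the way the model is selected are assumptions, check the example's clap definitions):

    cargo run --example mnist-training -- --save mnist.safetensors
    cargo run --example mnist-training -- --load mnist.safetensors --save mnist.safetensors

The first run trains and writes the weights in safetensors format; the second reloads them before the training loop starts and keeps checkpointing to the same file.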

@@ -83,6 +83,27 @@ impl VarStore {
         safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
         Ok(())
     }
+
+    fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
+        use candle::safetensors::Load;
+        let path = path.as_ref();
+        let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
+        let data = data.deserialize()?;
+        let mut tensor_data = self.data.lock().unwrap();
+        for (name, var) in tensor_data.tensors.iter_mut() {
+            match data.tensor(name) {
+                Ok(data) => {
+                    let data: Tensor = data.load(var.device())?;
+                    if let Err(err) = var.set(&data) {
+                        candle::bail!("error setting {name} using data from {path:?}: {err}",)
+                    }
+                }
+                Err(_) => candle::bail!("cannot find tensor for {name}"),
+            }
+        }
+        Ok(())
+    }
 }
 
 fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
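As a sketch of how the new load pairs with the serialization code just above it (assuming the method wrapping serialize_to_file is the example's save, and that the example's VarStore, Mlp, DType and Device types are in scope; none of this is part of the commit):

    // Train with a fresh store, then persist the weights.
    let dev = candle::Device::cuda_if_available(0)?;
    let vs = VarStore::new(DType::F32, dev.clone());
    let model = Mlp::new(vs.clone())?;
    // ... run the training loop ...
    vs.save("mnist.safetensors")?;

    // Later: rebuild the model so the variable names exist in the store,
    // then overwrite them with the checkpointed tensors.
    let mut vs = VarStore::new(DType::F32, dev);
    let model = Mlp::new(vs.clone())?;
    vs.load("mnist.safetensors")?;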
@@ -145,6 +166,7 @@ impl Model for Mlp {
 fn training_loop<M: Model>(
     m: candle_nn::vision::Dataset,
     learning_rate: f64,
+    load: Option<String>,
     save: Option<String>,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
@@ -156,9 +178,14 @@ fn training_loop<M: Model>(
         .unsqueeze(1)?
         .to_device(&dev)?;
-    let vs = VarStore::new(DType::F32, dev.clone());
+    let mut vs = VarStore::new(DType::F32, dev.clone());
     let model = M::new(vs.clone())?;
+    if let Some(load) = load {
+        println!("loading weights from {load}");
+        vs.load(&load)?
+    }
     let all_vars = vs.all_vars();
     let all_vars = all_vars.iter().collect::<Vec<_>>();
     let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
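The ordering here matters: M::new(vs.clone()) registers the named variables in the store, vs.load then overwrites their contents in place, and only afterwards does the SGD optimizer capture them through all_vars. A minimal sketch of the in-place update that load relies on (the variable shape and names are made up for illustration, assuming the usual candle constructors Var::zeros and Tensor::ones):

    // Var::set replaces the tensor backing an existing variable, so every
    // clone of the VarStore, and the model built from it, sees the loaded
    // weights without being rebuilt.
    let w = candle::Var::zeros((784, 100), DType::F32, &dev)?;
    let from_checkpoint = Tensor::ones((784, 100), DType::F32, &dev)?;
    w.set(&from_checkpoint)?;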
@@ -208,6 +235,10 @@ struct Args {
     /// The file where to save the trained weights, in safetensors format.
     #[arg(long)]
     save: Option<String>,
+
+    /// The file where to load the trained weights from, in safetensors format.
+    #[arg(long)]
+    load: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
@@ -221,8 +252,10 @@ pub fn main() -> anyhow::Result<()> {
     match args.model {
         WhichModel::Linear => {
-            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
+            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
+        }
+        WhichModel::Mlp => {
+            training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
         }
-        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
     }
 }