mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Load a trained checkpoint in the mnist example. (#280)
This commit is contained in:
@ -83,6 +83,27 @@ impl VarStore {
|
||||
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
|
||||
use candle::safetensors::Load;
|
||||
|
||||
let path = path.as_ref();
|
||||
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
|
||||
let data = data.deserialize()?;
|
||||
let mut tensor_data = self.data.lock().unwrap();
|
||||
for (name, var) in tensor_data.tensors.iter_mut() {
|
||||
match data.tensor(name) {
|
||||
Ok(data) => {
|
||||
let data: Tensor = data.load(var.device())?;
|
||||
if let Err(err) = var.set(&data) {
|
||||
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
||||
}
|
||||
}
|
||||
Err(_) => candle::bail!("cannot find tensor for {name}"),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
||||
@ -145,6 +166,7 @@ impl Model for Mlp {
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_nn::vision::Dataset,
|
||||
learning_rate: f64,
|
||||
load: Option<String>,
|
||||
save: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = candle::Device::cuda_if_available(0)?;
|
||||
@ -156,9 +178,14 @@ fn training_loop<M: Model>(
|
||||
.unsqueeze(1)?
|
||||
.to_device(&dev)?;
|
||||
|
||||
let vs = VarStore::new(DType::F32, dev.clone());
|
||||
let mut vs = VarStore::new(DType::F32, dev.clone());
|
||||
let model = M::new(vs.clone())?;
|
||||
|
||||
if let Some(load) = load {
|
||||
println!("loading weights from {load}");
|
||||
vs.load(&load)?
|
||||
}
|
||||
|
||||
let all_vars = vs.all_vars();
|
||||
let all_vars = all_vars.iter().collect::<Vec<_>>();
|
||||
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
|
||||
@ -208,6 +235,10 @@ struct Args {
|
||||
/// The file where to save the trained weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
save: Option<String>,
|
||||
|
||||
/// The file where to load the trained weights from, in safetensors format.
|
||||
#[arg(long)]
|
||||
load: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
@ -221,8 +252,10 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
match args.model {
|
||||
WhichModel::Linear => {
|
||||
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
|
||||
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
|
||||
}
|
||||
WhichModel::Mlp => {
|
||||
training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
|
||||
}
|
||||
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user