mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Load a trained checkpoint in the mnist example. (#280)
This commit is contained in:
@ -83,6 +83,27 @@ impl VarStore {
|
|||||||
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
|
||||||
|
use candle::safetensors::Load;
|
||||||
|
|
||||||
|
let path = path.as_ref();
|
||||||
|
let data = unsafe { candle::safetensors::MmapedFile::new(path)? };
|
||||||
|
let data = data.deserialize()?;
|
||||||
|
let mut tensor_data = self.data.lock().unwrap();
|
||||||
|
for (name, var) in tensor_data.tensors.iter_mut() {
|
||||||
|
match data.tensor(name) {
|
||||||
|
Ok(data) => {
|
||||||
|
let data: Tensor = data.load(var.device())?;
|
||||||
|
if let Err(err) = var.set(&data) {
|
||||||
|
candle::bail!("error setting {name} using data from {path:?}: {err}",)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => candle::bail!("cannot find tensor for {name}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
||||||
@ -145,6 +166,7 @@ impl Model for Mlp {
|
|||||||
fn training_loop<M: Model>(
|
fn training_loop<M: Model>(
|
||||||
m: candle_nn::vision::Dataset,
|
m: candle_nn::vision::Dataset,
|
||||||
learning_rate: f64,
|
learning_rate: f64,
|
||||||
|
load: Option<String>,
|
||||||
save: Option<String>,
|
save: Option<String>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let dev = candle::Device::cuda_if_available(0)?;
|
let dev = candle::Device::cuda_if_available(0)?;
|
||||||
@ -156,9 +178,14 @@ fn training_loop<M: Model>(
|
|||||||
.unsqueeze(1)?
|
.unsqueeze(1)?
|
||||||
.to_device(&dev)?;
|
.to_device(&dev)?;
|
||||||
|
|
||||||
let vs = VarStore::new(DType::F32, dev.clone());
|
let mut vs = VarStore::new(DType::F32, dev.clone());
|
||||||
let model = M::new(vs.clone())?;
|
let model = M::new(vs.clone())?;
|
||||||
|
|
||||||
|
if let Some(load) = load {
|
||||||
|
println!("loading weights from {load}");
|
||||||
|
vs.load(&load)?
|
||||||
|
}
|
||||||
|
|
||||||
let all_vars = vs.all_vars();
|
let all_vars = vs.all_vars();
|
||||||
let all_vars = all_vars.iter().collect::<Vec<_>>();
|
let all_vars = all_vars.iter().collect::<Vec<_>>();
|
||||||
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
|
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
|
||||||
@ -208,6 +235,10 @@ struct Args {
|
|||||||
/// The file where to save the trained weights, in safetensors format.
|
/// The file where to save the trained weights, in safetensors format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
save: Option<String>,
|
save: Option<String>,
|
||||||
|
|
||||||
|
/// The file where to load the trained weights from, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
load: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
@ -221,8 +252,10 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::Linear => {
|
WhichModel::Linear => {
|
||||||
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
|
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.load, args.save)
|
||||||
|
}
|
||||||
|
WhichModel::Mlp => {
|
||||||
|
training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.load, args.save)
|
||||||
}
|
}
|
||||||
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user