mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a flag to save the trained weights. (#279)
This commit is contained in:
@ -52,6 +52,27 @@ impl st::View for Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
impl st::View for &Tensor {
|
||||
fn dtype(&self) -> st::Dtype {
|
||||
(*self).dtype().into()
|
||||
}
|
||||
fn shape(&self) -> &[usize] {
|
||||
self.dims()
|
||||
}
|
||||
|
||||
fn data(&self) -> Cow<[u8]> {
|
||||
// This copies data from GPU to CPU.
|
||||
// TODO: Avoid the unwrap here.
|
||||
Cow::Owned(convert_back(self).unwrap())
|
||||
}
|
||||
|
||||
fn data_len(&self) -> usize {
|
||||
let n: usize = self.dims().iter().product();
|
||||
let bytes_per_element = (*self).dtype().size_in_bytes();
|
||||
n * bytes_per_element
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
pub fn save_safetensors<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
|
@ -15,6 +15,7 @@ candle = { path = "../candle-core" }
|
||||
candle-nn = { path = "../candle-nn" }
|
||||
candle-transformers = { path = "../candle-transformers" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
|
@ -76,6 +76,13 @@ impl VarStore {
|
||||
.map(|c| c.clone())
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
|
||||
let tensor_data = self.data.lock().unwrap();
|
||||
let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
|
||||
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
|
||||
@ -138,6 +145,7 @@ impl Model for Mlp {
|
||||
fn training_loop<M: Model>(
|
||||
m: candle_nn::vision::Dataset,
|
||||
learning_rate: f64,
|
||||
save: Option<String>,
|
||||
) -> anyhow::Result<()> {
|
||||
let dev = candle::Device::cuda_if_available(0)?;
|
||||
|
||||
@ -176,6 +184,10 @@ fn training_loop<M: Model>(
|
||||
100. * test_accuracy
|
||||
);
|
||||
}
|
||||
if let Some(save) = save {
|
||||
println!("saving trained weights in {save}");
|
||||
vs.save(&save)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -192,6 +204,10 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
learning_rate: Option<f64>,
|
||||
|
||||
/// The file where to save the trained weights, in safetensors format.
|
||||
#[arg(long)]
|
||||
save: Option<String>,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
@ -204,7 +220,9 @@ pub fn main() -> anyhow::Result<()> {
|
||||
println!("test-labels: {:?}", m.test_labels.shape());
|
||||
|
||||
match args.model {
|
||||
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
|
||||
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
|
||||
WhichModel::Linear => {
|
||||
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
|
||||
}
|
||||
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user