Mirror of https://github.com/huggingface/candle.git
Add a flag to save the trained weights. (#279)
@@ -52,6 +52,27 @@ impl st::View for Tensor {
     }
 }
 
+impl st::View for &Tensor {
+    fn dtype(&self) -> st::Dtype {
+        (*self).dtype().into()
+    }
+    fn shape(&self) -> &[usize] {
+        self.dims()
+    }
+
+    fn data(&self) -> Cow<[u8]> {
+        // This copies data from GPU to CPU.
+        // TODO: Avoid the unwrap here.
+        Cow::Owned(convert_back(self).unwrap())
+    }
+
+    fn data_len(&self) -> usize {
+        let n: usize = self.dims().iter().product();
+        let bytes_per_element = (*self).dtype().size_in_bytes();
+        n * bytes_per_element
+    }
+}
+
 impl Tensor {
     pub fn save_safetensors<P: AsRef<std::path::Path>>(
         &self,
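Because `safetensors::View` is now implemented for `&Tensor`, borrowed tensors can be handed directly to the `safetensors` serializer without cloning them into an owned buffer first. A minimal sketch of what this enables (the helper function, tensor names, and file path below are illustrative, not part of the commit):

    // Illustrative helper: serialize two borrowed tensors into a safetensors file.
    fn save_weights(w: &candle::Tensor, b: &candle::Tensor) -> anyhow::Result<()> {
        // `&Tensor: safetensors::View`, so only the GPU-to-CPU copy inside `data()`
        // is paid; no extra clone of the tensors is needed.
        let tensors = vec![("weight", w), ("bias", b)];
        safetensors::tensor::serialize_to_file(tensors, &None, std::path::Path::new("weights.safetensors"))?;
        Ok(())
    }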
@@ -15,6 +15,7 @@ candle = { path = "../candle-core" }
 candle-nn = { path = "../candle-nn" }
 candle-transformers = { path = "../candle-transformers" }
 candle-flash-attn = { path = "../candle-flash-attn", optional = true }
+safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 num-traits = { workspace = true }
@@ -76,6 +76,13 @@ impl VarStore {
             .map(|c| c.clone())
             .collect::<Vec<_>>()
     }
+
+    fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
 }
 
 fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
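`VarStore::save` serializes every tracked variable under its registered name via the same `serialize_to_file` call. As a sanity check, the resulting file can be read back with the `safetensors` crate; a small sketch (the file path is assumed):

    // Illustrative check: read a saved file back and list the stored tensors.
    fn inspect(path: &str) -> anyhow::Result<()> {
        let bytes = std::fs::read(path)?;
        let st = safetensors::SafeTensors::deserialize(&bytes)?;
        for (name, view) in st.tensors() {
            println!("{name}: {:?} {:?}", view.dtype(), view.shape());
        }
        Ok(())
    }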
@@ -138,6 +145,7 @@ impl Model for Mlp {
 fn training_loop<M: Model>(
     m: candle_nn::vision::Dataset,
     learning_rate: f64,
+    save: Option<String>,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
@@ -176,6 +184,10 @@ fn training_loop<M: Model>(
             100. * test_accuracy
         );
     }
+    if let Some(save) = save {
+        println!("saving trained weights in {save}");
+        vs.save(&save)?
+    }
     Ok(())
 }
 
@@ -192,6 +204,10 @@ struct Args {
 
     #[arg(long)]
     learning_rate: Option<f64>,
+
+    /// The file where to save the trained weights, in safetensors format.
+    #[arg(long)]
+    save: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
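For context, a stripped-down sketch of how the new clap flag surfaces at runtime; the real `Args` struct has more fields (e.g. the model selector), which are omitted here:

    use clap::Parser;

    // Reduced to the two options touched by this commit.
    #[derive(Parser)]
    struct TrainArgs {
        #[arg(long)]
        learning_rate: Option<f64>,

        /// The file where to save the trained weights, in safetensors format.
        #[arg(long)]
        save: Option<String>,
    }

    fn main() {
        // Passing e.g. `--save mnist.safetensors` on the command line populates `args.save`.
        let args = TrainArgs::parse();
        if let Some(path) = &args.save {
            println!("trained weights will be written to {path}");
        }
    }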
@@ -204,7 +220,9 @@ pub fn main() -> anyhow::Result<()> {
     println!("test-labels: {:?}", m.test_labels.shape());
 
     match args.model {
-        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
-        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+        WhichModel::Linear => {
+            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
+        }
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
     }
 }
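With this, both branches of the match forward `args.save` into `training_loop`, which writes the weights once training finishes. A sketch of the equivalent direct call (the dataset `m` is assumed to already be loaded as in the example's `main`, and the file name is illustrative):

    // Same as selecting the linear model and passing `--save mnist-linear.safetensors`.
    training_loop::<LinearModel>(m, 1.0, Some("mnist-linear.safetensors".to_string()))?;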