Add a flag to save the trained weights. (#279)

This commit is contained in:
Laurent Mazare
2023-07-30 15:41:42 +01:00
committed by GitHub
parent ba2254556c
commit 38ff693af0
3 changed files with 42 additions and 2 deletions

View File

@ -52,6 +52,27 @@ impl st::View for Tensor {
}
}
impl st::View for &Tensor {
fn dtype(&self) -> st::Dtype {
(*self).dtype().into()
}
fn shape(&self) -> &[usize] {
self.dims()
}
fn data(&self) -> Cow<[u8]> {
// This copies data from GPU to CPU.
// TODO: Avoid the unwrap here.
Cow::Owned(convert_back(self).unwrap())
}
fn data_len(&self) -> usize {
let n: usize = self.dims().iter().product();
let bytes_per_element = (*self).dtype().size_in_bytes();
n * bytes_per_element
}
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,

View File

@ -15,6 +15,7 @@ candle = { path = "../candle-core" }
candle-nn = { path = "../candle-nn" }
candle-transformers = { path = "../candle-transformers" }
candle-flash-attn = { path = "../candle-flash-attn", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }

View File

@ -76,6 +76,13 @@ impl VarStore {
.map(|c| c.clone())
.collect::<Vec<_>>()
}
fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
Ok(())
}
}
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
@ -138,6 +145,7 @@ impl Model for Mlp {
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
learning_rate: f64,
save: Option<String>,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
@ -176,6 +184,10 @@ fn training_loop<M: Model>(
100. * test_accuracy
);
}
if let Some(save) = save {
println!("saving trained weights in {save}");
vs.save(&save)?
}
Ok(())
}
@ -192,6 +204,10 @@ struct Args {
#[arg(long)]
learning_rate: Option<f64>,
/// The file where to save the trained weights, in safetensors format.
#[arg(long)]
save: Option<String>,
}
pub fn main() -> anyhow::Result<()> {
@ -204,7 +220,9 @@ pub fn main() -> anyhow::Result<()> {
println!("test-labels: {:?}", m.test_labels.shape());
match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
WhichModel::Linear => {
training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
}
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
}
}