Add safetensors saving to the MNIST training example.

Summary of the changes carried by this patch:

  * candle-core/src/safetensors.rs: implement `st::View` for `&Tensor`
    (dtype, shape, data, data_len). `data()` goes through `convert_back`
    and unwraps the result, so a failed conversion panics; the TODO inside
    the hunk tracks this. Presumably this impl exists so that borrowed
    tensors can be passed to the safetensors serializer.
  * candle-examples/Cargo.toml: add the `safetensors` workspace dependency.
  * candle-examples/examples: rename `simple-training` to `mnist-training`,
    add `VarStore::save` (serializes the stored tensors with
    `safetensors::tensor::serialize_to_file`), add a `save` long option to
    `Args`, and have `training_loop` write the weights once training is
    done when that option is set.

NOTE(review): this patch was recovered from a copy that had been collapsed
onto two physical lines with every angle-bracketed generic list stripped.
Line breaks and indentation were rebuilt from the hunk line counts and
rustfmt conventions. The generic arguments (`<P: AsRef<std::path::Path>>`,
`<Vec<_>>`, `<Linear>`, `<M: Model>`, `<f64>`, `<String>`, `<LinearModel>`,
`<Mlp>`) were re-inserted from context. Confirm both against the upstream
commit before applying.

diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index dee57b37..0e1cc655 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -52,6 +52,27 @@ impl st::View for Tensor {
     }
 }
 
+impl st::View for &Tensor {
+    fn dtype(&self) -> st::Dtype {
+        (*self).dtype().into()
+    }
+    fn shape(&self) -> &[usize] {
+        self.dims()
+    }
+
+    fn data(&self) -> Cow<[u8]> {
+        // This copies data from GPU to CPU.
+        // TODO: Avoid the unwrap here.
+        Cow::Owned(convert_back(self).unwrap())
+    }
+
+    fn data_len(&self) -> usize {
+        let n: usize = self.dims().iter().product();
+        let bytes_per_element = (*self).dtype().size_in_bytes();
+        n * bytes_per_element
+    }
+}
+
 impl Tensor {
     pub fn save_safetensors<P: AsRef<std::path::Path>>(
         &self,
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 5a04130f..ff28c646 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -15,6 +15,7 @@ candle = { path = "../candle-core" }
 candle-nn = { path = "../candle-nn" }
 candle-transformers = { path = "../candle-transformers" }
 candle-flash-attn = { path = "../candle-flash-attn", optional = true }
+safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 num-traits = { workspace = true }
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/mnist-training/main.rs
similarity index 89%
rename from candle-examples/examples/simple-training/main.rs
rename to candle-examples/examples/mnist-training/main.rs
index b78d937b..bdf28e5d 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -76,6 +76,13 @@ impl VarStore {
             .map(|c| c.clone())
             .collect::<Vec<_>>()
     }
+
+    fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
+        let tensor_data = self.data.lock().unwrap();
+        let data = tensor_data.tensors.iter().map(|(k, v)| (k, v.as_tensor()));
+        safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
+        Ok(())
+    }
 }
 
 fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
@@ -138,6 +145,7 @@ impl Model for Mlp {
 fn training_loop<M: Model>(
     m: candle_nn::vision::Dataset,
     learning_rate: f64,
+    save: Option<String>,
 ) -> anyhow::Result<()> {
     let dev = candle::Device::cuda_if_available(0)?;
 
@@ -176,6 +184,10 @@ fn training_loop<M: Model>(
             100. * test_accuracy
         );
     }
+    if let Some(save) = save {
+        println!("saving trained weights in {save}");
+        vs.save(&save)?
+    }
     Ok(())
 }
 
@@ -192,6 +204,10 @@ struct Args {
 
     #[arg(long)]
     learning_rate: Option<f64>,
+
+    /// The file where to save the trained weights, in safetensors format.
+    #[arg(long)]
+    save: Option<String>,
 }
 
 pub fn main() -> anyhow::Result<()> {
@@ -204,7 +220,9 @@ pub fn main() -> anyhow::Result<()> {
     println!("test-labels: {:?}", m.test_labels.shape());
 
     match args.model {
-        WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
-        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
+        WhichModel::Linear => {
+            training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.), args.save)
+        }
+        WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01), args.save),
     }
 }