From ded93a116983da7c84ea224a6191bcbc3a7fdef1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 13 Jul 2023 19:05:44 +0100
Subject: [PATCH] Add the SGD optimizer (#160)

* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slightly simpler.
---
 candle-core/src/convert.rs  | 96 +++++++++++++++++++++++++++++++++++++
 candle-core/src/lib.rs      |  1 +
 candle-core/src/variable.rs |  7 ++-
 candle-nn/src/lib.rs        |  2 +
 candle-nn/src/optim.rs      | 47 ++++++++++++++++++
 candle-nn/tests/optim.rs    | 19 ++++++++
 6 files changed, 168 insertions(+), 4 deletions(-)
 create mode 100644 candle-core/src/convert.rs
 create mode 100644 candle-nn/src/optim.rs
 create mode 100644 candle-nn/tests/optim.rs

diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs
new file mode 100644
index 00000000..41a9c4ee
--- /dev/null
+++ b/candle-core/src/convert.rs
@@ -0,0 +1,96 @@
+//! Implement conversion traits for tensors
+use crate::{Device, Error, Tensor, WithDType};
+use half::{bf16, f16};
+use std::convert::TryFrom;
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
+    type Error = Error;
+    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+        tensor.to_vec1::<T>()
+    }
+}
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
+    type Error = Error;
+    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+        tensor.to_vec2::<T>()
+    }
+}
+
+impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
+    type Error = Error;
+    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+        tensor.to_vec3::<T>()
+    }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
+    type Error = Error;
+    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+        Vec::<T>::try_from(&tensor)
+    }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
+    type Error = Error;
+    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+        Vec::<Vec<T>>::try_from(&tensor)
+    }
+}
+
+impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
+    type Error = Error;
+    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+        Vec::<Vec<Vec<T>>>::try_from(&tensor)
+    }
+}
+
+impl<T: WithDType> TryFrom<&[T]> for Tensor {
+    type Error = Error;
+    fn try_from(v: &[T]) -> Result<Self, Self::Error> {
+        Tensor::from_slice(v, v.len(), &Device::Cpu)
+    }
+}
+
+impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
+    type Error = Error;
+    fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
+        let len = v.len();
+        Tensor::from_vec(v, len, &Device::Cpu)
+    }
+}
+
+macro_rules! from_tensor {
+    ($typ:ident) => {
+        impl TryFrom<&Tensor> for $typ {
+            type Error = Error;
+
+            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
+                tensor.to_scalar::<$typ>()
+            }
+        }
+
+        impl TryFrom<Tensor> for $typ {
+            type Error = Error;
+
+            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
+                $typ::try_from(&tensor)
+            }
+        }
+
+        impl TryFrom<$typ> for Tensor {
+            type Error = Error;
+
+            fn try_from(v: $typ) -> Result<Self, Self::Error> {
+                Tensor::new(v, &Device::Cpu)
+            }
+        }
+    };
+}
+
+from_tensor!(f64);
+from_tensor!(f32);
+from_tensor!(f16);
+from_tensor!(bf16);
+from_tensor!(u32);
+from_tensor!(u8);
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 254e2c99..f11bad6e 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -36,6 +36,7 @@
 mod backend;
 mod backprop;
 mod conv;
+mod convert;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index b9051ed6..0ae16c64 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -1,13 +1,12 @@
 // Variables are wrappers around tensors that can be modified, they are typically used for holding
 // weights and being modified by gradient descent.
-// They are not cloneable by default to avoid having too many potential writers on the data.
-// We also do not expose a public way to create variables as this would break the invariant that
-// the tensor within a variable is actually with `is_variable` set to `true`.
+// We do not expose a public way to create variables as this would break the invariant that the
+// tensor within a variable is actually with `is_variable` set to `true`.
 use crate::{DType, Device, Error, Result, Shape, Tensor};
 
 /// A variable is a wrapper around a tensor, however variables can have their content modified
 /// whereas tensors are immutable.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Var(Tensor);
 
 impl std::ops::Deref for Var {
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index bb168661..0eb2d8e1 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -6,6 +6,7 @@ pub mod embedding;
 pub mod init;
 pub mod layer_norm;
 pub mod linear;
+pub mod optim;
 pub mod var_builder;
 
 pub use activation::Activation;
@@ -13,4 +14,5 @@ pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
+pub use optim::SGD;
 pub use var_builder::VarBuilder;
diff --git a/candle-nn/src/optim.rs b/candle-nn/src/optim.rs
new file mode 100644
index 00000000..741c51dc
--- /dev/null
+++ b/candle-nn/src/optim.rs
@@ -0,0 +1,47 @@
+//! Various optimization algorithms.
+use candle::{Result, Tensor, Var};
+
+#[derive(Debug)]
+pub struct SGD {
+    vars: Vec<Var>,
+    learning_rate: f64,
+}
+
+impl SGD {
+    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
+        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
+        Self {
+            vars,
+            learning_rate,
+        }
+    }
+
+    pub fn empty(learning_rate: f64) -> Self {
+        Self {
+            vars: vec![],
+            learning_rate,
+        }
+    }
+
+    pub fn into_inner(self) -> Vec<Var> {
+        self.vars
+    }
+
+    pub fn learning_rate(&self) -> f64 {
+        self.learning_rate
+    }
+
+    pub fn push(&mut self, var: &Var) {
+        self.vars.push(var.clone())
+    }
+
+    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
+        let grads = loss.backward()?;
+        for var in self.vars.iter() {
+            if let Some(grad) = grads.get(var) {
+                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/candle-nn/tests/optim.rs b/candle-nn/tests/optim.rs
new file mode 100644
index 00000000..29aa987b
--- /dev/null
+++ b/candle-nn/tests/optim.rs
@@ -0,0 +1,19 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle::{Device, Var};
+use candle_nn::SGD;
+
+#[test]
+fn sgd_optim() -> Result<()> {
+    let x = Var::new(0f32, &Device::Cpu)?;
+    let sgd = SGD::new(&[&x], 0.1);
+    let xt = x.as_tensor();
+    for _step in 0..100 {
+        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
+        sgd.backward_step(&loss)?
+    }
+    assert_eq!(x.to_scalar::<f32>()?, 4.199999);
+    Ok(())
+}
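The backward_step function added above computes the gradients of the loss and updates each registered variable in place with var <- var - learning_rate * grad. The sketch below shows how the new candle_nn::SGD API might be driven, building the optimizer incrementally via SGD::empty and push rather than SGD::new as the test does; the driver function, variable names, and toy objective are illustrative and not part of the patch.

```rust
// Minimal sketch (illustrative only): fit two scalar variables so that their
// sum approaches 2.0, using the SGD optimizer added in candle-nn/src/optim.rs.
use candle::{Device, Result, Var};
use candle_nn::SGD;

fn main() -> Result<()> {
    let w = Var::new(0f32, &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;

    // Register variables one by one instead of passing them all to SGD::new.
    let mut sgd = SGD::empty(0.1);
    sgd.push(&w);
    sgd.push(&b);

    for _step in 0..200 {
        // loss = (w + b - 2)^2; backward_step computes the gradients and
        // applies var -= learning_rate * grad to both registered variables.
        let sum = w.as_tensor().add(b.as_tensor())?;
        let loss = ((&sum - 2.0)? * (&sum - 2.0)?)?;
        sgd.backward_step(&loss)?;
    }
    println!("w = {}, b = {}", w.to_scalar::<f32>()?, b.to_scalar::<f32>()?);
    Ok(())
}
```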