Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Add the SGD optimizer (#160)
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slightly simpler.
candle-core/src/convert.rs (new file, 96 lines)
@@ -0,0 +1,96 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use std::convert::TryFrom;

impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
    type Error = Error;
    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
        tensor.to_vec1::<T>()
    }
}

impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
    type Error = Error;
    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
        tensor.to_vec2::<T>()
    }
}

impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
    type Error = Error;
    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
        tensor.to_vec3::<T>()
    }
}

impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
    type Error = Error;
    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
        Vec::<T>::try_from(&tensor)
    }
}

impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
    type Error = Error;
    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
        Vec::<Vec<T>>::try_from(&tensor)
    }
}

impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
    type Error = Error;
    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
        Vec::<Vec<Vec<T>>>::try_from(&tensor)
    }
}

impl<T: WithDType> TryFrom<&[T]> for Tensor {
    type Error = Error;
    fn try_from(v: &[T]) -> Result<Self, Self::Error> {
        Tensor::from_slice(v, v.len(), &Device::Cpu)
    }
}

impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
    type Error = Error;
    fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
        let len = v.len();
        Tensor::from_vec(v, len, &Device::Cpu)
    }
}

macro_rules! from_tensor {
    ($typ:ident) => {
        impl TryFrom<&Tensor> for $typ {
            type Error = Error;

            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
                tensor.to_scalar::<$typ>()
            }
        }

        impl TryFrom<Tensor> for $typ {
            type Error = Error;

            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
                $typ::try_from(&tensor)
            }
        }

        impl TryFrom<$typ> for Tensor {
            type Error = Error;

            fn try_from(v: $typ) -> Result<Self, Self::Error> {
                Tensor::new(v, &Device::Cpu)
            }
        }
    };
}

from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(u32);
from_tensor!(u8);
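To illustrate what these conversion traits buy, here is a minimal sketch (not part of the diff; it assumes the impls are reachable through the public `candle` crate just like `Tensor` itself):

use candle::{Result, Tensor};
use std::convert::TryFrom;

fn conversions() -> Result<()> {
    // Vec<T> -> Tensor via TryFrom<Vec<T>> (a 1D tensor on the CPU)...
    let t = Tensor::try_from(vec![1f32, 2., 3.])?;
    // ...and back again via TryFrom<&Tensor>, which goes through to_vec1.
    let v = Vec::<f32>::try_from(&t)?;
    assert_eq!(v, vec![1f32, 2., 3.]);

    // Scalars use the from_tensor! impls (Tensor::new / to_scalar).
    let s = Tensor::try_from(4.2f64)?;
    assert_eq!(f64::try_from(&s)?, 4.2);
    Ok(())
}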
candle-core/src/lib.rs
@@ -36,6 +36,7 @@
 mod backend;
 mod backprop;
 mod conv;
+mod convert;
 mod cpu_backend;
 #[cfg(feature = "cuda")]
 mod cuda_backend;
candle-core/src/var.rs
@@ -1,13 +1,12 @@
 // Variables are wrappers around tensors that can be modified, they are typically used for holding
 // weights and being modified by gradient descent.
-// They are not cloneable by default to avoid having too many potential writers on the data.
-// We also do not expose a public way to create variables as this would break the invariant that
-// the tensor within a variable is actually with `is_variable` set to `true`.
+// We do not expose a public way to create variables as this would break the invariant that the
+// tensor within a variable is actually with `is_variable` set to `true`.
 
 use crate::{DType, Device, Error, Result, Shape, Tensor};
 
 /// A variable is a wrapper around a tensor, however variables can have their content modified
 /// whereas tensors are immutable.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Var(Tensor);
 
 impl std::ops::Deref for Var {
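The new `Clone` derive is what lets `SGD::new` below keep its own handles to the variables it updates: a cloned `Var` refers to the same underlying storage, so updates applied through the optimizer's clones via `set` are visible through the caller's handles (the test at the bottom depends on this). A minimal sketch of that behaviour, not part of the diff:

use candle::{Device, Result, Tensor, Var};

fn clone_shares_storage() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    // Cloning is now possible thanks to #[derive(Clone, Debug)].
    let y = x.clone();
    // Writing through the clone with Var::set...
    y.set(&Tensor::new(1f32, &Device::Cpu)?)?;
    // ...is observable through the original handle, which is what lets
    // the optimizer update variables owned by the caller.
    assert_eq!(x.to_scalar::<f32>()?, 1.0);
    Ok(())
}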
candle-nn/src/lib.rs
@@ -6,6 +6,7 @@ pub mod embedding;
 pub mod init;
 pub mod layer_norm;
 pub mod linear;
+pub mod optim;
 pub mod var_builder;
 
 pub use activation::Activation;
@@ -13,4 +14,5 @@ pub use conv::{Conv1d, Conv1dConfig};
 pub use embedding::Embedding;
 pub use layer_norm::LayerNorm;
 pub use linear::Linear;
+pub use optim::SGD;
 pub use var_builder::VarBuilder;
candle-nn/src/optim.rs (new file, 47 lines)
@@ -0,0 +1,47 @@
//! Various optimization algorithms.
use candle::{Result, Tensor, Var};

#[derive(Debug)]
pub struct SGD {
    vars: Vec<Var>,
    learning_rate: f64,
}

impl SGD {
    pub fn new(vars: &[&Var], learning_rate: f64) -> Self {
        let vars: Vec<_> = vars.iter().map(|&v| v.clone()).collect();
        Self {
            vars,
            learning_rate,
        }
    }

    pub fn empty(learning_rate: f64) -> Self {
        Self {
            vars: vec![],
            learning_rate,
        }
    }

    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

    pub fn learning_rate(&self) -> f64 {
        self.learning_rate
    }

    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }

    pub fn backward_step(&self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                var.set(&var.sub(&(grad * self.learning_rate)?)?)?
            }
        }
        Ok(())
    }
}
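`backward_step` computes the gradients of `loss` with respect to every tracked variable and applies the plain SGD update `v <- v - learning_rate * grad`. Besides passing all variables to `SGD::new` up front (as the test below does), variables can also be registered incrementally via `empty` and `push`; a minimal sketch of that path, not part of the diff (the variable and the loss expression are illustrative):

use candle::{Device, Result, Var};
use candle_nn::SGD;

fn incremental_registration() -> Result<()> {
    // Start with an optimizer that tracks nothing, then push variables in.
    let mut sgd = SGD::empty(0.1);
    let w = Var::new(0f32, &Device::Cpu)?;
    sgd.push(&w);

    // One step of w <- w - 0.1 * d/dw (w - 1)^2, i.e. 0 - 0.1 * (-2) = 0.2.
    let loss = ((w.as_tensor() - 1.0)? * (w.as_tensor() - 1.0)?)?;
    sgd.backward_step(&loss)?;
    assert!((w.to_scalar::<f32>()? - 0.2).abs() < 1e-6);
    Ok(())
}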
candle-nn/tests/optim.rs (new file, 19 lines)
@@ -0,0 +1,19 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::Result;
use candle::{Device, Var};
use candle_nn::SGD;

#[test]
fn sgd_optim() -> Result<()> {
    let x = Var::new(0f32, &Device::Cpu)?;
    let sgd = SGD::new(&[&x], 0.1);
    let xt = x.as_tensor();
    for _step in 0..100 {
        let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
        sgd.backward_step(&loss)?
    }
    assert_eq!(x.to_scalar::<f32>()?, 4.199999);
    Ok(())
}
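Why the assertion expects 4.199999 rather than exactly 4.2: with loss (x - 4.2)^2 and learning rate 0.1, one step applies x <- x - 0.1 * 2 * (x - 4.2) = 0.8 * x + 0.84, so the gap to 4.2 shrinks by a factor of 0.8 per step. After 100 steps it is about 4.2 * 0.8^100 ≈ 9e-10; the small remaining difference in the asserted value comes from f32 rounding, not from the optimizer failing to converge.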