Add the SGD optimizer (#160)

* Add the nn::optim and some conversion traits.

* Add the backward_step function for SGD.

* Get the SGD optimizer to work and add a test.

* Make the test slighly simpler.
This commit is contained in:
Laurent Mazare
2023-07-13 19:05:44 +01:00
committed by GitHub
parent 5ee3c95582
commit ded93a1169
6 changed files with 168 additions and 4 deletions

View File

@ -0,0 +1,96 @@
//! Implement conversion traits for tensors
use crate::{Device, Error, Tensor, WithDType};
use half::{bf16, f16};
use std::convert::TryFrom;
impl<T: WithDType> TryFrom<&Tensor> for Vec<T> {
type Error = Error;
fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
tensor.to_vec1::<T>()
}
}
impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<T>> {
type Error = Error;
fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
tensor.to_vec2::<T>()
}
}
impl<T: WithDType> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
type Error = Error;
fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
tensor.to_vec3::<T>()
}
}
impl<T: WithDType> TryFrom<Tensor> for Vec<T> {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
Vec::<T>::try_from(&tensor)
}
}
impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<T>> {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
Vec::<Vec<T>>::try_from(&tensor)
}
}
impl<T: WithDType> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
Vec::<Vec<Vec<T>>>::try_from(&tensor)
}
}
impl<T: WithDType> TryFrom<&[T]> for Tensor {
type Error = Error;
fn try_from(v: &[T]) -> Result<Self, Self::Error> {
Tensor::from_slice(v, v.len(), &Device::Cpu)
}
}
impl<T: WithDType> TryFrom<Vec<T>> for Tensor {
type Error = Error;
fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
let len = v.len();
Tensor::from_vec(v, len, &Device::Cpu)
}
}
macro_rules! from_tensor {
($typ:ident) => {
impl TryFrom<&Tensor> for $typ {
type Error = Error;
fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
tensor.to_scalar::<$typ>()
}
}
impl TryFrom<Tensor> for $typ {
type Error = Error;
fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
$typ::try_from(&tensor)
}
}
impl TryFrom<$typ> for Tensor {
type Error = Error;
fn try_from(v: $typ) -> Result<Self, Self::Error> {
Tensor::new(v, &Device::Cpu)
}
}
};
}
from_tensor!(f64);
from_tensor!(f32);
from_tensor!(f16);
from_tensor!(bf16);
from_tensor!(u32);
from_tensor!(u8);

View File

@ -36,6 +36,7 @@
mod backend;
mod backprop;
mod conv;
mod convert;
mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;

View File

@ -1,13 +1,12 @@
// Variables are wrappers around tensors that can be modified, they are typically used for holding
// weights and being modified by gradient descent.
// They are not cloneable by default to avoid having too many potential writers on the data.
// We also do not expose a public way to create variables as this would break the invariant that
// the tensor within a variable is actually with `is_variable` set to `true`.
// We do not expose a public way to create variables as this would break the invariant that the
// tensor within a variable is actually with `is_variable` set to `true`.
use crate::{DType, Device, Error, Result, Shape, Tensor};
/// A variable is a wrapper around a tensor, however variables can have their content modified
/// whereas tensors are immutable.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct Var(Tensor);
impl std::ops::Deref for Var {