Move the variable creation to the variable module. (#159)

* Move the variable creation to the variable module.

* Make it possible to set a variable.

* Add some basic gradient descent test.

* Get the gradient descent test to work.
Laurent Mazare, 2023-07-13 16:55:40 +01:00, committed by GitHub
parent 6991036bc5
commit 5ee3c95582
5 changed files with 137 additions and 80 deletions
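
In user code, the net effect is that variable construction moves off `Tensor` and onto the `Var` wrapper (renamed from `Variable`), which derefs to `Tensor`. A rough before/after sketch of the touched call sites, using only constructors that appear in the diffs below:

    // Before: variables were created through Tensor-side constructors.
    // let x = Tensor::var(&[3f32, 1., 4.], &Device::Cpu)?;
    // let w = Tensor::zeros_var((2, 3), DType::F32, &Device::Cpu)?;

    // After: the same constructors live on Var.
    let x = Var::new(&[3f32, 1., 4.], &Device::Cpu)?;
    let w = Var::zeros((2, 3), DType::F32, &Device::Cpu)?;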

@@ -109,6 +109,9 @@ pub enum Error {
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },

+    #[error("cannot set variable {msg}")]
+    CannotSetVar { msg: &'static str },
+
     // Box indirection to avoid large variant.
     #[error("{0:?}")]
     MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),

@@ -68,7 +68,7 @@ pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
-pub use variable::Variable;
+pub use variable::Var;
 #[cfg(feature = "cuda")]
 pub use cuda_backend::{CudaDevice, CudaStorage};

@@ -148,7 +148,7 @@ fn from_storage<S: Into<Shape>>(
 }

 impl Tensor {
-    fn ones_impl<S: Into<Shape>>(
+    pub(crate) fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -171,12 +171,6 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }

-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        // Maybe we should allocate some actual storage for vars rather than just using a
-        // broadcasted scalar?
-        Self::ones_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
@@ -190,16 +184,9 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), self.device())
     }

-    /// Creates a new tensor filled with zeros.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, DType, Device};
-    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
-    /// // a == b
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    fn zeros_impl<S: Into<Shape>>(
+    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
+    // the variable module.
+    pub(crate) fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -222,10 +209,6 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, false)
     }

-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::zeros_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other
     /// tensor.
     ///
@@ -240,7 +223,7 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }

-    fn rand_impl<S: Into<Shape>>(
+    pub(crate) fn rand_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -264,17 +247,7 @@ impl Tensor {
         Self::rand_impl(s, dtype, device, lo, up, false)
     }

-    pub fn rand_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        lo: f64,
-        up: f64,
-    ) -> Result<Self> {
-        Self::rand_impl(s, dtype, device, lo, up, true)
-    }
-
-    fn randn_impl<S: Into<Shape>>(
+    pub(crate) fn randn_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -299,17 +272,7 @@ impl Tensor {
         Self::randn_impl(s, dtype, device, mean, std, false)
     }

-    pub fn randn_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        mean: f64,
-        std: f64,
-    ) -> Result<Self> {
-        Self::randn_impl(s, dtype, device, mean, std, true)
-    }
-
-    pub fn new_impl<A: crate::device::NdArray>(
+    pub(crate) fn new_impl<A: crate::device::NdArray>(
         array: A,
         shape: Shape,
         device: &Device,
@@ -330,13 +293,6 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }

-    /// Creates a new tensor on the specified device using the content and shape of the input.
-    /// This is similar to `new` but the resulting tensor is a variable.
-    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        let shape = array.shape()?;
-        Self::new_impl(array, shape, device, true)
-    }
-
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
@@ -371,7 +327,7 @@ impl Tensor {
         Self::from_vec_impl(data, len, device, false)
     }

-    fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -397,14 +353,6 @@ impl Tensor {
         Self::from_vec_impl(data, shape, device, false)
     }

-    pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
-        data: Vec<D>,
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::from_vec_impl(data, shape, device, true)
-    }
-
     /// Creates a new tensor initialized with values from the input slice. The number of elements
     /// in this vector must be the same as the number of elements defined by the shape.
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
@@ -415,14 +363,6 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, false)
     }

-    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
-        array: &[D],
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::new_impl(array, shape.into(), device, true)
-    }
-
     pub(crate) fn broadcast_shape_binary_op<'a>(
         &'a self,
         rhs: &'a Self,
@@ -1532,11 +1472,26 @@ impl Tensor {
         self.storage.read().unwrap()
     }

+    // If we extend the visibility of this function to be usable outside of this crate, we should
+    // make it unsafe.
+    pub(crate) fn storage_mut_and_layout(
+        &self,
+    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
+        let storage = self.storage.write().unwrap();
+        (storage, &self.layout)
+    }
+
     /// The storage used by this tensor, together with the layout to use to access it safely.
     pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
         let storage = self.storage.read().unwrap();
         (storage, &self.layout)
     }
+
+    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
+        let lhs: &RwLock<Storage> = self.storage.as_ref();
+        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
+        std::ptr::eq(lhs, rhs)
+    }
 }

 macro_rules! bin_trait {

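The `same_storage` helper added above guards `Var::set` (next file): `set` takes a write lock on the destination storage and then a read lock on the source, and if both guards targeted the same `RwLock`, the second acquisition would deadlock or panic. A minimal standalone sketch of that pattern, using plain std types rather than the candle ones:

    use std::sync::RwLock;

    // Without the pointer check, copy_into(&s, &s) would write-lock `dst` and
    // then try to read-lock the very same RwLock, which deadlocks or panics.
    // Lengths are assumed equal, mirroring the shape check in `set`.
    fn copy_into(dst: &RwLock<Vec<f32>>, src: &RwLock<Vec<f32>>) -> Result<(), &'static str> {
        if std::ptr::eq(dst, src) {
            return Err("cannot copy a buffer into itself");
        }
        let mut dst = dst.write().unwrap();
        let src = src.read().unwrap();
        dst.copy_from_slice(&src);
        Ok(())
    }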
@@ -3,14 +3,14 @@
 // They are not cloneable by default to avoid having too many potential writers on the data.
 // We also do not expose a public way to create variables as this would break the invariant that
 // the tensor within a variable is actually with `is_variable` set to `true`.
-use crate::Tensor;
+use crate::{DType, Device, Error, Result, Shape, Tensor};

 /// A variable is a wrapper around a tensor, however variables can have their content modified
 /// whereas tensors are immutable.
 #[derive(Debug)]
-pub struct Variable(Tensor);
+pub struct Var(Tensor);

-impl std::ops::Deref for Variable {
+impl std::ops::Deref for Var {
     type Target = Tensor;

     fn deref(&self) -> &Self::Target {
@@ -18,13 +18,95 @@ impl std::ops::Deref for Variable {
     }
 }

-impl Variable {
+impl Var {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        lo: f64,
+        up: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+        Ok(Self(inner))
+    }
+
+    /// Creates a new tensor on the specified device using the content and shape of the input.
+    /// This is similar to `new` but the resulting tensor is a variable.
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
+        let shape = array.shape()?;
+        let inner = Tensor::new_impl(array, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }

-    /// Consumes this `Variable` and return the underlying tensor.
+    /// Consumes this `Var` and return the underlying tensor.
     pub fn into_inner(self) -> Tensor {
         self.0
     }
+
+    /// Sets the content of the inner tensor, this does not require a mutable reference as inner
+    /// mutability is used.
+    pub fn set(&self, src: &Tensor) -> Result<()> {
+        if self.same_storage(src) {
+            let msg = "cannot set a variable to a tensor that is derived from its value";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (mut dst, layout) = self.storage_mut_and_layout();
+        if !layout.is_contiguous() {
+            let msg = "cannot set a non-contiguous variable";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (src, src_l) = src.storage_and_layout();
+        if layout.shape() != src_l.shape() {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: layout.shape().clone(),
+                rhs: src_l.shape().clone(),
+                op: "set",
+            })?
+        }
+        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
+        Ok(())
+    }
 }

@@ -1,12 +1,13 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Tensor};
+use candle::{Device, Shape, Var};

 mod test_utils;

 fn simple_grad(device: &Device) -> Result<()> {
-    let x = Tensor::var(&[3f32, 1., 4.], device)?;
-    let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
+    let x = Var::new(&[3f32, 1., 4.], device)?;
+    let x = x.as_tensor();
+    let y = (((x * x)? + x * 5f64)? + 4f64)?;
     let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
+    let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
     // y = x^2 + 5.x + 4
     assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
@@ -17,9 +18,9 @@ fn simple_grad(device: &Device) -> Result<()> {

 fn matmul_grad(device: &Device) -> Result<()> {
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
+    let x = Var::from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
+    let y = Var::from_slice(&data, (2, 3, 2), device)?;
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
@@ -43,5 +44,21 @@ fn matmul_grad(device: &Device) -> Result<()> {
     Ok(())
 }

+// The simplest gradient descent, using scalar variable.
+fn grad_descent(device: &Device) -> Result<()> {
+    let x = Var::new(0f32, device)?;
+    let learning_rate = 0.1;
+    for _step in 0..100 {
+        let xt = x.as_tensor();
+        let c = ((xt - 4.2)? * (xt - 4.2)?)?;
+        let grads = c.backward()?;
+        let x_grad = grads.get(&x).context("no grad for x")?;
+        x.set(&(xt - x_grad * learning_rate)?)?
+    }
+    assert_eq!(x.to_scalar::<f32>()?, 4.199999);
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
+test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
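
For reference on the expected value: each iteration of `grad_descent` computes c = (x - 4.2)^2 with dc/dx = 2(x - 4.2), so the update is x <- x - 0.1 * 2 * (x - 4.2) = 0.8 * x + 0.84. The distance to 4.2 shrinks by a factor of 0.8 per step, leaving roughly 4.2 * 0.8^100 ≈ 9e-10 after 100 steps, well below f32 resolution near 4.2. The iteration therefore settles on its f32 fixed point, which is presumably why the assertion can pin the slightly rounded 4.199999 exactly instead of comparing against 4.2 with a tolerance.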