diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 2ab2ec1d..8746c2fe 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -109,6 +109,9 @@ pub enum Error {
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 
+    #[error("cannot set variable {msg}")]
+    CannotSetVar { msg: &'static str },
+
     // Box indirection to avoid large variant.
     #[error("{0:?}")]
     MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index af1eb215..254e2c99 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -68,7 +68,7 @@ pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
-pub use variable::Variable;
+pub use variable::Var;
 
 #[cfg(feature = "cuda")]
 pub use cuda_backend::{CudaDevice, CudaStorage};
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ffd190ca..481a6851 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -148,7 +148,7 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    fn ones_impl<S: Into<Shape>>(
+    pub(crate) fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -171,12 +171,6 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        // Maybe we should allocate some actual storage for vars rather than just using a
-        // broadcasted scalar?
-        Self::ones_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
@@ -190,16 +184,9 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), self.device())
     }
 
-    /// Creates a new tensor filled with zeros.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, DType, Device};
-    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
-    /// // a == b
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    fn zeros_impl<S: Into<Shape>>(
+    // Do not expose outside of the crate; the `is_variable=true` case should only be accessed from
+    // the variable module.
+    pub(crate) fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -222,10 +209,6 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::zeros_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other
     /// tensor.
     ///
@@ -240,7 +223,7 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
-    fn rand_impl<S: Into<Shape>>(
+    pub(crate) fn rand_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -264,17 +247,7 @@ impl Tensor {
         Self::rand_impl(s, dtype, device, lo, up, false)
     }
 
-    pub fn rand_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        lo: f64,
-        up: f64,
-    ) -> Result<Self> {
-        Self::rand_impl(s, dtype, device, lo, up, true)
-    }
-
-    fn randn_impl<S: Into<Shape>>(
+    pub(crate) fn randn_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -299,17 +272,7 @@ impl Tensor {
         Self::randn_impl(s, dtype, device, mean, std, false)
     }
 
-    pub fn randn_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        mean: f64,
-        std: f64,
-    ) -> Result<Self> {
-        Self::randn_impl(s, dtype, device, mean, std, true)
-    }
-
-    pub fn new_impl<A: crate::device::NdArray>(
+    pub(crate) fn new_impl<A: crate::device::NdArray>(
         array: A,
         shape: Shape,
         device: &Device,
@@ -330,13 +293,6 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }
 
-    /// Creates a new tensor on the specified device using the content and shape of the input.
-    /// This is similar to `new` but the resulting tensor is a variable.
-    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        let shape = array.shape()?;
-        Self::new_impl(array, shape, device, true)
-    }
-
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
         iter: impl IntoIterator<Item = D>,
@@ -371,7 +327,7 @@ impl Tensor {
         Self::from_vec_impl(data, len, device, false)
     }
 
-    fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -397,14 +353,6 @@ impl Tensor {
         Self::from_vec_impl(data, shape, device, false)
     }
 
-    pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
-        data: Vec<D>,
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::from_vec_impl(data, shape, device, true)
-    }
-
     /// Creates a new tensor initialized with values from the input slice. The number of elements
     /// in this vector must be the same as the number of elements defined by the shape.
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
@@ -415,14 +363,6 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, false)
     }
 
-    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
-        array: &[D],
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::new_impl(array, shape.into(), device, true)
-    }
-
     pub(crate) fn broadcast_shape_binary_op<'a>(
         &'a self,
         rhs: &'a Self,
@@ -1532,11 +1472,26 @@ impl Tensor {
         self.storage.read().unwrap()
     }
 
+    // If we extend the visibility of this function to be usable outside of this crate, we should
+    // make it unsafe.
+    pub(crate) fn storage_mut_and_layout(
+        &self,
+    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
+        let storage = self.storage.write().unwrap();
+        (storage, &self.layout)
+    }
+
     /// The storage used by this tensor, together with the layout to use to access it safely.
     pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
         let storage = self.storage.read().unwrap();
         (storage, &self.layout)
     }
+
+    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
+        let lhs: &RwLock<Storage> = self.storage.as_ref();
+        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
+        std::ptr::eq(lhs, rhs)
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 67675765..b9051ed6 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -3,14 +3,14 @@
 // They are not cloneable by default to avoid having too many potential writers on the data.
 // We also do not expose a public way to create variables as this would break the invariant that
 // the tensor within a variable is actually with `is_variable` set to `true`.
-use crate::Tensor;
+use crate::{DType, Device, Error, Result, Shape, Tensor};
 
 /// A variable is a wrapper around a tensor, however variables can have their content modified
 /// whereas tensors are immutable.
 #[derive(Debug)]
-pub struct Variable(Tensor);
+pub struct Var(Tensor);
 
-impl std::ops::Deref for Variable {
+impl std::ops::Deref for Var {
     type Target = Tensor;
 
     fn deref(&self) -> &Self::Target {
@@ -18,13 +18,95 @@ impl std::ops::Deref for Variable {
         &self.0
     }
 }
 
-impl Variable {
+impl Var {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        lo: f64,
+        up: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+        Ok(Self(inner))
+    }
+
+    /// Creates a new tensor on the specified device using the content and shape of the input.
+    /// This is similar to `Tensor::new` but the resulting tensor is a variable.
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
+        let shape = array.shape()?;
+        let inner = Tensor::new_impl(array, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }
 
-    /// Consumes this `Variable` and return the underlying tensor.
+    /// Consumes this `Var` and returns the underlying tensor.
     pub fn into_inner(self) -> Tensor {
         self.0
     }
+
+    /// Sets the content of the inner tensor; this does not require a mutable reference as inner
+    /// mutability is used.
+    pub fn set(&self, src: &Tensor) -> Result<()> {
+        if self.same_storage(src) {
+            let msg = "cannot set a variable to a tensor that is derived from its value";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (mut dst, layout) = self.storage_mut_and_layout();
+        if !layout.is_contiguous() {
+            let msg = "cannot set a non-contiguous variable";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (src, src_l) = src.storage_and_layout();
+        if layout.shape() != src_l.shape() {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: layout.shape().clone(),
+                rhs: src_l.shape().clone(),
+                op: "set",
+            })?
+        }
+        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
+        Ok(())
+    }
 }
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index d5c8f751..6f11879f 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -1,12 +1,13 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Tensor};
+use candle::{Device, Shape, Var};
 mod test_utils;
 
 fn simple_grad(device: &Device) -> Result<()> {
-    let x = Tensor::var(&[3f32, 1., 4.], device)?;
-    let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
+    let x = Var::new(&[3f32, 1., 4.], device)?;
+    let x = x.as_tensor();
+    let y = (((x * x)? + x * 5f64)? + 4f64)?;
     let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
+    let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
     // y = x^2 + 5.x + 4
     assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
@@ -17,9 +18,9 @@
 
 fn matmul_grad(device: &Device) -> Result<()> {
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
+    let x = Var::from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
+    let y = Var::from_slice(&data, (2, 3, 2), device)?;
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
@@ -43,5 +44,21 @@ fn matmul_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
+// The simplest gradient descent, using a scalar variable.
+fn grad_descent(device: &Device) -> Result<()> {
+    let x = Var::new(0f32, device)?;
+    let learning_rate = 0.1;
+    for _step in 0..100 {
+        let xt = x.as_tensor();
+        let c = ((xt - 4.2)? * (xt - 4.2)?)?;
+        let grads = c.backward()?;
+        let x_grad = grads.get(&x).context("no grad for x")?;
+        x.set(&(xt - x_grad * learning_rate)?)?
+    }
+    assert_eq!(x.to_scalar::<f32>()?, 4.199999);
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
+test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
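Usage sketch (not part of the patch): a minimal, hypothetical example of how the new `Var` API above is meant to be exercised. It only relies on items visible in this diff (`Var::zeros`, `Var::set`, `as_tensor`, `Tensor::new`, `to_vec1`, `DType::F32`, `Device::Cpu`) plus `anyhow` as already used by the tests; the `main` harness and variable names are illustrative.

```rust
use anyhow::Result;
use candle::{DType, Device, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A `Var` wraps a tensor that was created with `is_variable` set to `true`.
    let w = Var::zeros(3, DType::F32, &dev)?;
    // `set` overwrites the variable's storage in place; it only needs `&self`
    // because the storage sits behind a lock (see `storage_mut_and_layout`).
    w.set(&Tensor::new(&[1f32, 2., 3.], &dev)?)?;
    assert_eq!(w.to_vec1::<f32>()?, [1., 2., 3.]);
    // Setting a variable from a tensor that shares its own storage is rejected
    // with the new `CannotSetVar` error, as is a shape mismatch.
    assert!(w.set(w.as_tensor()).is_err());
    Ok(())
}
```

`set` deliberately takes `&self` rather than `&mut self`: the tensor storage already lives behind an `RwLock`, so gradient-descent style updates (see the `grad_descent` test above) can mutate a variable while the surrounding code holds only shared references.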