Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00
Move the variable creation to the variable module. (#159)
* Move the variable creation to the variable module.
* Make it possible to set a variable.
* Add some basic gradient descent test.
* Get the gradient descent test to work.
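For context, a minimal usage sketch of the API after this change (not part of the diff; adapted from the tests added below, with a hypothetical function name and the anyhow error handling used in those tests):

use anyhow::{Context, Result};
use candle::{Device, Var};

fn sgd_step_demo() -> Result<()> {
    let device = Device::Cpu;
    // Variables are now created through `Var` (Var::new, Var::zeros, Var::ones,
    // Var::rand, Var::randn, Var::from_vec, Var::from_slice) instead of the
    // removed Tensor::var* constructors.
    let x = Var::new(0f32, &device)?;
    // `Var` derefs to `Tensor`, so the usual tensor ops apply to `x.as_tensor()`.
    let xt = x.as_tensor();
    let loss = ((xt - 4.2)? * (xt - 4.2)?)?;
    let grads = loss.backward()?;
    let x_grad = grads.get(&x).context("no grad for x")?;
    // In-place update through `Var::set`, which relies on interior mutability.
    x.set(&(xt - x_grad * 0.1)?)?;
    Ok(())
}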
@@ -109,6 +109,9 @@ pub enum Error {
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 
+    #[error("cannot set variable {msg}")]
+    CannotSetVar { msg: &'static str },
+
     // Box indirection to avoid large variant.
     #[error("{0:?}")]
     MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
@@ -68,7 +68,7 @@ pub use shape::{Shape, D};
 pub use storage::Storage;
 use strided_index::StridedIndex;
 pub use tensor::{Tensor, TensorId};
-pub use variable::Variable;
+pub use variable::Var;
 
 #[cfg(feature = "cuda")]
 pub use cuda_backend::{CudaDevice, CudaStorage};
@@ -148,7 +148,7 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    fn ones_impl<S: Into<Shape>>(
+    pub(crate) fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -171,12 +171,6 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        // Maybe we should allocate some actual storage for vars rather than just using a
-        // broadcasted scalar?
-        Self::ones_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
@@ -190,16 +184,9 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), self.device())
     }
 
-    /// Creates a new tensor filled with zeros.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, DType, Device};
-    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
-    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
-    /// // a == b
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    fn zeros_impl<S: Into<Shape>>(
+    // Do not expose outside of the crate, the `is_variable=true` case should only be accessed from
+    // the variable module.
+    pub(crate) fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
@@ -222,10 +209,6 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::zeros_impl(shape, dtype, device, true)
-    }
-
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other
     /// tensor.
     ///
@@ -240,7 +223,7 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), self.device())
     }
 
-    fn rand_impl<S: Into<Shape>>(
+    pub(crate) fn rand_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -264,17 +247,7 @@ impl Tensor {
         Self::rand_impl(s, dtype, device, lo, up, false)
     }
 
-    pub fn rand_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        lo: f64,
-        up: f64,
-    ) -> Result<Self> {
-        Self::rand_impl(s, dtype, device, lo, up, true)
-    }
-
-    fn randn_impl<S: Into<Shape>>(
+    pub(crate) fn randn_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -299,17 +272,7 @@ impl Tensor {
         Self::randn_impl(s, dtype, device, mean, std, false)
     }
 
-    pub fn randn_var<S: Into<Shape>>(
-        s: S,
-        dtype: DType,
-        device: &Device,
-        mean: f64,
-        std: f64,
-    ) -> Result<Self> {
-        Self::randn_impl(s, dtype, device, mean, std, true)
-    }
-
-    pub fn new_impl<A: crate::device::NdArray>(
+    pub(crate) fn new_impl<A: crate::device::NdArray>(
         array: A,
         shape: Shape,
         device: &Device,
@@ -330,13 +293,6 @@ impl Tensor {
         Self::new_impl(array, shape, device, false)
     }
 
-    /// Creates a new tensor on the specified device using the content and shape of the input.
-    /// This is similar to `new` but the resulting tensor is a variable.
-    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        let shape = array.shape()?;
-        Self::new_impl(array, shape, device, true)
-    }
-
     /// Creates a new 1D tensor from an iterator.
     pub fn from_iter<D: crate::WithDType>(
        iter: impl IntoIterator<Item = D>,
@@ -371,7 +327,7 @@ impl Tensor {
         Self::from_vec_impl(data, len, device, false)
     }
 
-    fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -397,14 +353,6 @@ impl Tensor {
         Self::from_vec_impl(data, shape, device, false)
     }
 
-    pub fn var_from_vec<S: Into<Shape>, D: crate::WithDType>(
-        data: Vec<D>,
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::from_vec_impl(data, shape, device, true)
-    }
-
     /// Creates a new tensor initialized with values from the input slice. The number of elements
     /// in this vector must be the same as the number of elements defined by the shape.
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
@@ -415,14 +363,6 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, false)
     }
 
-    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
-        array: &[D],
-        shape: S,
-        device: &Device,
-    ) -> Result<Self> {
-        Self::new_impl(array, shape.into(), device, true)
-    }
-
     pub(crate) fn broadcast_shape_binary_op<'a>(
         &'a self,
         rhs: &'a Self,
@@ -1532,11 +1472,26 @@ impl Tensor {
         self.storage.read().unwrap()
     }
 
+    // If we extend the visibility of this function to be usable outside of this crate, we should
+    // make it unsafe.
+    pub(crate) fn storage_mut_and_layout(
+        &self,
+    ) -> (std::sync::RwLockWriteGuard<'_, Storage>, &Layout) {
+        let storage = self.storage.write().unwrap();
+        (storage, &self.layout)
+    }
+
     /// The storage used by this tensor, together with the layout to use to access it safely.
     pub fn storage_and_layout(&self) -> (std::sync::RwLockReadGuard<'_, Storage>, &Layout) {
         let storage = self.storage.read().unwrap();
         (storage, &self.layout)
     }
+
+    pub(crate) fn same_storage(&self, rhs: &Self) -> bool {
+        let lhs: &RwLock<Storage> = self.storage.as_ref();
+        let rhs: &RwLock<Storage> = rhs.storage.as_ref();
+        std::ptr::eq(lhs, rhs)
+    }
 }
 
 macro_rules! bin_trait {
@@ -3,14 +3,14 @@
 // They are not cloneable by default to avoid having too many potential writers on the data.
 // We also do not expose a public way to create variables as this would break the invariant that
 // the tensor within a variable is actually with `is_variable` set to `true`.
-use crate::Tensor;
+use crate::{DType, Device, Error, Result, Shape, Tensor};
 
 /// A variable is a wrapper around a tensor, however variables can have their content modified
 /// whereas tensors are immutable.
 #[derive(Debug)]
-pub struct Variable(Tensor);
+pub struct Var(Tensor);
 
-impl std::ops::Deref for Variable {
+impl std::ops::Deref for Var {
     type Target = Tensor;
 
     fn deref(&self) -> &Self::Target {
@@ -18,13 +18,95 @@ impl std::ops::Deref for Variable {
     }
 }
 
-impl Variable {
+impl Var {
+    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn rand<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        lo: f64,
+        up: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::rand_impl(s, dtype, device, lo, up, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn randn<S: Into<Shape>>(
+        s: S,
+        dtype: DType,
+        device: &Device,
+        mean: f64,
+        std: f64,
+    ) -> Result<Self> {
+        let inner = Tensor::randn_impl(s, dtype, device, mean, std, true)?;
+        Ok(Self(inner))
+    }
+
+    /// Creates a new tensor on the specified device using the content and shape of the input.
+    /// This is similar to `new` but the resulting tensor is a variable.
+    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
+        let shape = array.shape()?;
+        let inner = Tensor::new_impl(array, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+        data: Vec<D>,
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
+        Ok(Self(inner))
+    }
+
+    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
+        Ok(Self(inner))
+    }
+
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }
 
-    /// Consumes this `Variable` and return the underlying tensor.
+    /// Consumes this `Var` and return the underlying tensor.
     pub fn into_inner(self) -> Tensor {
         self.0
     }
+
+    /// Sets the content of the inner tensor, this does not require a mutable reference as inner
+    /// mutability is used.
+    pub fn set(&self, src: &Tensor) -> Result<()> {
+        if self.same_storage(src) {
+            let msg = "cannot set a variable to a tensor that is derived from its value";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (mut dst, layout) = self.storage_mut_and_layout();
+        if !layout.is_contiguous() {
+            let msg = "cannot set a non-contiguous variable";
+            Err(Error::CannotSetVar { msg })?
+        }
+        let (src, src_l) = src.storage_and_layout();
+        if layout.shape() != src_l.shape() {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: layout.shape().clone(),
+                rhs: src_l.shape().clone(),
+                op: "set",
+            })?
+        }
+        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
+        Ok(())
+    }
 }
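As a rough illustration (not part of the commit) of the checks `Var::set` performs in the hunk above, namely rejecting a source that shares storage with the variable, a non-contiguous destination layout, and mismatched shapes:

use candle::{Device, Error, Tensor, Var};

fn set_examples() -> Result<(), Error> {
    let device = Device::Cpu;
    let v = Var::new(&[0f32, 0., 0.], &device)?;

    // Fine: same shape, independent storage.
    v.set(&Tensor::new(&[1f32, 2., 3.], &device)?)?;

    // Rejected: shapes differ, reported as ShapeMismatchBinaryOp with op "set".
    assert!(v.set(&Tensor::new(&[1f32, 2.], &device)?).is_err());

    // Rejected: the source shares storage with the variable itself
    // (the `same_storage` check), reported as Error::CannotSetVar.
    assert!(v.set(v.as_tensor()).is_err());
    Ok(())
}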
@@ -1,12 +1,13 @@
 use anyhow::{Context, Result};
-use candle::{Device, Shape, Tensor};
+use candle::{Device, Shape, Var};
 mod test_utils;
 
 fn simple_grad(device: &Device) -> Result<()> {
-    let x = Tensor::var(&[3f32, 1., 4.], device)?;
-    let y = (((&x * &x)? + &x * 5f64)? + 4f64)?;
+    let x = Var::new(&[3f32, 1., 4.], device)?;
+    let x = x.as_tensor();
+    let y = (((x * x)? + x * 5f64)? + 4f64)?;
     let grads = y.backward()?;
-    let grad_x = grads.get(&x).context("no grad for x")?;
+    let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
     // y = x^2 + 5.x + 4
     assert_eq!(y.to_vec1::<f32>()?, [28., 10., 40.]);
@@ -17,9 +18,9 @@ fn simple_grad(device: &Device) -> Result<()> {
 
 fn matmul_grad(device: &Device) -> Result<()> {
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
+    let x = Var::from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
+    let y = Var::from_slice(&data, (2, 3, 2), device)?;
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
@@ -43,5 +44,21 @@ fn matmul_grad(device: &Device) -> Result<()> {
     Ok(())
 }
 
+// The simplest gradient descent, using scalar variable.
+fn grad_descent(device: &Device) -> Result<()> {
+    let x = Var::new(0f32, device)?;
+    let learning_rate = 0.1;
+    for _step in 0..100 {
+        let xt = x.as_tensor();
+        let c = ((xt - 4.2)? * (xt - 4.2)?)?;
+        let grads = c.backward()?;
+        let x_grad = grads.get(&x).context("no grad for x")?;
+        x.set(&(xt - x_grad * learning_rate)?)?
+    }
+    assert_eq!(x.to_scalar::<f32>()?, 4.199999);
+    Ok(())
+}
+
 test_device!(simple_grad, simple_grad_cpu, simple_grad_gpu);
 test_device!(matmul_grad, matmul_grad_cpu, matmul_grad_gpu);
+test_device!(grad_descent, grad_descent_cpu, grad_descent_gpu);
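A quick sanity check on the new grad_descent test: with cost (x - 4.2)^2 the update is x <- x - 0.1 * 2 * (x - 4.2) = 0.8 * x + 0.84, so the error x - 4.2 shrinks by a factor of 0.8 per step. After 100 steps it is roughly 0.8^100 * 4.2, about 8e-10, far below f32 resolution around 4.2, so the iteration has settled on a stable f32 value, which is the 4.199999 the assert compares against.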