From a86ec4b9f030fdc952dabd1ca975f6a22c177402 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 12 Jul 2023 17:40:17 +0100
Subject: [PATCH] Add more documentation and examples. (#149)

* Add more documentation and examples.

* More documentation and tests.

* Document more tensor functions.

* Again more examples and tests.
---
 candle-core/src/tensor.rs | 215 +++++++++++++++++++++++++++++++++-----
 1 file changed, 189 insertions(+), 26 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 947c6b77..a174edd0 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -140,7 +140,7 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }

-    /// Create a new tensors filled with ones
+    /// Creates a new tensor filled with ones.
     ///
     /// ```rust
     /// use candle::{Tensor, DType, Device};
@@ -159,8 +159,7 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, true)
     }

-    /// Create a new tensors filled with ones with same shape, dtype, and device
-    /// as the other tensors
+    /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
     /// use candle::{Tensor, DType, Device};
@@ -173,7 +172,7 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }

-    /// Create a new tensors filled with zeros
+    /// Creates a new tensor filled with zeros.
     ///
     /// ```rust
     /// use candle::{Tensor, DType, Device};
@@ -192,7 +191,7 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }

-    /// Create a new tensors filled with zeros
+    /// Creates a new tensor filled with zeros.
     ///
     /// ```rust
     /// use candle::{Tensor, DType, Device};
@@ -209,8 +208,8 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, true)
     }

-    /// Create a new tensors filled with ones with same shape, dtype, and device
-    /// as the other tensors
+    /// Creates a new tensor filled with zeros with same shape, dtype, and device as the other
+    /// tensor.
     ///
     /// ```rust
     /// use candle::{Tensor, DType, Device};
@@ -223,7 +222,7 @@ impl Tensor {
         Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }

-    fn rand_uniform_impl<S: Into<Shape>>(
+    fn rand_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -236,27 +235,28 @@ impl Tensor {
         Ok(from_storage(storage, s, None, is_variable))
     }

-    pub fn rand_uniform<S: Into<Shape>>(
+    /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
+    pub fn rand<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
         lo: f64,
         up: f64,
     ) -> Result<Self> {
-        Self::rand_uniform_impl(s, dtype, device, lo, up, false)
+        Self::rand_impl(s, dtype, device, lo, up, false)
     }

-    pub fn rand_uniform_var<S: Into<Shape>>(
+    pub fn rand_var<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
         lo: f64,
         up: f64,
     ) -> Result<Self> {
-        Self::rand_uniform_impl(s, dtype, device, lo, up, true)
+        Self::rand_impl(s, dtype, device, lo, up, true)
     }

-    fn rand_normal_impl<S: Into<Shape>>(
+    fn randn_impl<S: Into<Shape>>(
         s: S,
         dtype: DType,
         device: &Device,
@@ -269,24 +269,26 @@ impl Tensor {
         Ok(from_storage(storage, s, None, is_variable))
     }

-    pub fn rand_normal<S: Into<Shape>>(
+    /// Creates a new tensor initialized with values sampled from a normal distribution with the
+    /// specified `mean` and standard deviation `std`.
+ pub fn randn>( s: S, dtype: DType, device: &Device, mean: f64, std: f64, ) -> Result { - Self::rand_normal_impl(s, dtype, device, mean, std, false) + Self::randn_impl(s, dtype, device, mean, std, false) } - pub fn rand_normal_var>( + pub fn randn_var>( s: S, dtype: DType, device: &Device, mean: f64, std: f64, ) -> Result { - Self::rand_normal_impl(s, dtype, device, mean, std, true) + Self::randn_impl(s, dtype, device, mean, std, true) } pub fn new_impl( @@ -304,17 +306,20 @@ impl Tensor { Ok(from_storage(storage, shape, None, is_variable)) } + /// Creates a new tensor on the specified device using the content and shape of the input. pub fn new(array: A, device: &Device) -> Result { let shape = array.shape()?; Self::new_impl(array, shape, device, false) } + /// Creates a new tensor on the specified device using the content and shape of the input. + /// This is similar to `new` but the resulting tensor is a variable. pub fn var(array: A, device: &Device) -> Result { let shape = array.shape()?; Self::new_impl(array, shape, device, true) } - /// Create a new 1D tensor from an iterator. + /// Creates a new 1D tensor from an iterator. pub fn from_iter( iter: impl IntoIterator, device: &Device, @@ -324,13 +329,13 @@ impl Tensor { Self::from_vec_impl(data, len, device, false) } - /// Create a new 1D tensor with values from the interval `[start, end)` taken with a common + /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `1` from `start`. pub fn arange(start: D, end: D, device: &Device) -> Result { Self::arange_step(start, end, D::one(), device) } - /// Create a new 1D tensor with values from the interval `[start, end)` taken with a common + /// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common /// difference `step` from `start`. pub fn arange_step( start: D, @@ -363,6 +368,9 @@ impl Tensor { Ok(from_storage(storage, shape, None, is_variable)) } + /// Creates a new tensor initialized with values from the input vector. The number of elements + /// in this vector must be the same as the number of elements defined by the shape. + /// If the device is cpu, no data copy is made. pub fn from_vec, D: crate::WithDType>( data: Vec, shape: S, @@ -379,6 +387,8 @@ impl Tensor { Self::from_vec_impl(data, shape, device, true) } + /// Creates a new tensor initialized with values from the input slice. The number of elements + /// in this vector must be the same as the number of elements defined by the shape. pub fn from_slice, D: crate::WithDType>( array: &[D], shape: S, @@ -478,6 +488,8 @@ impl Tensor { unary_op!(gelu, Gelu); unary_op!(relu, Relu); + /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple + /// dimensions, an error is returned instead. pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { @@ -496,6 +508,17 @@ impl Tensor { } } + /// This operation multiplies the input tensor by `mul` then adds `add` and return the result. + /// The input values `mul` and `add` are casted to the appropriate type so some rounding might + /// be performed. 
+ /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let a = a.affine(4., -2.)?; + /// assert_eq!(a.to_vec2::()?, &[[-2.0, 2.0], [6.0, 10.0]]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn affine(&self, mul: f64, add: f64) -> Result { let storage = self.storage.affine(self.layout(), mul, add)?; let op = if self.track_op() { @@ -510,6 +533,7 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } + /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor. pub fn elu(&self, alpha: f64) -> Result { let storage = self.storage.elu(self.layout(), alpha)?; let op = if self.track_op() { @@ -566,6 +590,21 @@ impl Tensor { } } + /// Applies the softmax function to the input tensor, rescaling the element so that elements on + /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?; + /// let a = a.softmax(1)?; + /// assert_eq!( + /// a.to_vec2::()?, + /// &[ + /// [0.13447072, 0.3655293, 0.13447072, 0.3655293], + /// [0.004892866, 0.26714143, 0.7261657, 0.0017999847], + /// ]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn softmax(&self, dim: D) -> Result { let dim = dim.to_index(self.shape(), "softmax")?; // TODO: unify the two branches. @@ -589,6 +628,23 @@ impl Tensor { } } + /// Returns the sum of all elements in the input tensor. The sum is performed over all the + /// input dimensions. + /// + /// The resulting tensor as a shape that is similar to the shape of the input tensor, except + /// that the number of elements for each dimension index in `sum_dims` is 1. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let s = a.sum(&[0])?; + /// assert_eq!(s.to_vec2::()?, &[[2., 4.]]); + /// let s = a.sum(&[1])?; + /// assert_eq!(s.to_vec2::()?, &[[1.], [5.]]); + /// let s = a.sum(&[0, 1])?; + /// assert_eq!(s.to_vec2::()?, &[[6.]]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn sum(&self, sum_dims: &[usize]) -> Result { for &dim in sum_dims { self.check_dim(dim, "sum")?; @@ -606,6 +662,7 @@ impl Tensor { Ok(from_storage(storage, dims, op, false)) } + /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result { let (c_out, c_in_k, k_size) = kernel.shape().r3()?; let (b_size, c_in, l_in) = match *self.dims() { @@ -654,6 +711,14 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. + /// + /// # Arguments + /// + /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`. + /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`. + /// + /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`. pub fn matmul(&self, rhs: &Self) -> Result { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); @@ -698,6 +763,9 @@ impl Tensor { Ok(from_storage(storage, c_shape, op, false)) } + /// Returns a tensor with the same shape as the input tensor, the values are taken from + /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + /// input tensor is equal to zero. 
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result { let _shap = self.same_shape_binary_op(on_true, "where_cond")?; let shape = self.same_shape_binary_op(on_false, "where_cond")?; @@ -720,6 +788,25 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + /// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the + /// values hold in the `ids` tensor. + /// + /// # Arguments + /// + /// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive). + /// * `rhs` - A tensor with dimensions `v, h`. + /// + /// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the + /// vocabulary size, and `h` the hidden size. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?; + /// let emb = Tensor::embedding(&ids, &rhs)?; + /// assert_eq!(emb.to_vec2::()?, &[[4., 5.], [2., 3.], [4., 5.]]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn embedding(ids: &Self, rhs: &Self) -> Result { if !rhs.is_contiguous() { return Err(Error::RequiresContiguous { op: "embedding" }); @@ -766,6 +853,7 @@ impl Tensor { } } + /// Returns the data contained in a 1D tensor as a vector of scalar values. pub fn to_vec1(&self) -> Result> { if self.rank() != 1 { return Err(Error::UnexpectedNumberOfDims { @@ -788,6 +876,7 @@ impl Tensor { } } + /// Returns the data contained in a 2D tensor as a vector of vector of scalar values. pub fn to_vec2(&self) -> Result>> { let (dim1, dim2) = self.shape().r2()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { @@ -807,6 +896,7 @@ impl Tensor { } } + /// Returns the data contained in a 3D tensor. pub fn to_vec3(&self) -> Result>>> { let (dim1, dim2, dim3) = self.shape().r3()?; let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { @@ -830,27 +920,34 @@ impl Tensor { } } + /// The dtype for the elements stored in the input tensor. pub fn dtype(&self) -> DType { self.storage.dtype() } + /// The device on which the input tensor is located. pub fn device(&self) -> Device { self.storage.device() } + /// The tensor shape, i.e. dimension sizes on each axis. pub fn shape(&self) -> &Shape { self.layout().shape() } + /// The dimension size for this tensor on each axis. pub fn dims(&self) -> &[usize] { self.shape().dims() } + /// The dimension size for a specified dimension index. pub fn dim(&self, dim: D) -> Result { let dim = dim.to_index(self.shape(), "dim")?; Ok(self.dims()[dim]) } + /// The layout of the input tensor, this stores both the shape of the tensor as well as the + /// strides and the start offset to apply to the underlying storage. pub fn layout(&self) -> &Layout { &self.layout } @@ -859,18 +956,23 @@ impl Tensor { self.layout.stride() } + /// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc. pub fn rank(&self) -> usize { self.shape().rank() } + /// The number of elements stored in this tensor. pub fn elem_count(&self) -> usize { self.shape().elem_count() } + /// The unique identifier for this tensor. pub fn id(&self) -> TensorId { self.id } + /// Whether this tensor is a variable or not. A variable is a tensor for which gradient is + /// tracked and on which backpropagation can be performed. 
pub fn is_variable(&self) -> bool { self.is_variable } @@ -879,9 +981,19 @@ impl Tensor { &self.op } + /// Computes the sum of all the elements in this tensor and returns a tensor holding this + /// scalar with zero dimensions. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.sum_all()?; + /// assert_eq!(tensor.to_scalar::()?, 15.); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn sum_all(&self) -> Result { let dims: Vec<_> = (0..self.rank()).collect(); - self.sum(&dims) + self.sum(&dims)?.reshape(()) } fn flatten_( @@ -914,22 +1026,47 @@ impl Tensor { } } + /// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both + /// inclusive). pub fn flatten(&self, start_dim: D1, end_dim: D2) -> Result { self.flatten_(Some(start_dim), Some(end_dim)) } + /// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive). pub fn flatten_to(&self, end_dim: D) -> Result { self.flatten_(None::, Some(end_dim)) } + /// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last + /// dimension. pub fn flatten_from(&self, start_dim: D) -> Result { self.flatten_(Some(start_dim), None::) } + /// Flattens the input tensor by reshaping it into a one dimension tensor. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.flatten_all()?; + /// assert_eq!(tensor.to_vec1::()?, &[0., 1., 2., 3., 4., 5.]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn flatten_all(&self) -> Result { self.flatten_(None::, None::) } + /// Returns the sub-tensor fixing the index at `i` on the first dimension. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let t = tensor.get(0)?; + /// assert_eq!(t.to_vec1::()?, &[0., 1.]); + /// let t = tensor.get(1)?; + /// assert_eq!(t.to_vec1::()?, &[2., 3.]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn get(&self, i: usize) -> Result { let dims = self.dims(); if dims.is_empty() { @@ -941,6 +1078,14 @@ impl Tensor { /// Returns a tensor that is a transposed version of the input, the two last dimensions of the /// input are swapped. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?; + /// let tensor = tensor.t()?; + /// assert_eq!(tensor.to_vec2::()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn t(&self) -> Result { let rank = self.rank(); if rank < 2 { @@ -997,7 +1142,7 @@ impl Tensor { } /// Returns a new tensor detached from the current graph, gradient are not propagated through - /// this new node. + /// this new node. The storage of this tensor is shared with the initial tensor. pub fn detach(&self) -> Result { let tensor_ = Tensor_ { id: TensorId::new(), @@ -1052,6 +1197,13 @@ impl Tensor { self.broadcast_as(dims) } + /// Broadcast the input tensor to the target shape. This returns an error if the input shape is + /// not compatible with the target shape. + /// + /// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or + /// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have + /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. 
If + /// `i_a` is equal to 1, any value can be used. pub fn broadcast_as>(&self, shape: S) -> Result { let op = if self.track_op() { Some(Op::Broadcast(self.clone())) @@ -1073,6 +1225,16 @@ impl Tensor { self.broadcast_as(shape) } + /// Casts the input tensor to the target `dtype`. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?; + /// assert_eq!(tensor.to_scalar::()?, 3.14159265358979); + /// let tensor = tensor.to_dtype(candle::DType::F32)?; + /// assert_eq!(tensor.to_scalar::()?, 3.1415927); + /// # Ok::<(), candle::Error>(()) + /// ``` pub fn to_dtype(&self, dtype: DType) -> Result { if self.dtype() == dtype { Ok(self.clone()) @@ -1088,6 +1250,8 @@ impl Tensor { } } + /// Returns a tensor that is in row major order. This is the same as the original tensor if it + /// was already contiguous, otherwise a copy is triggered. pub fn contiguous(&self) -> Result { if self.is_contiguous() { Ok(self.clone()) @@ -1153,7 +1317,7 @@ impl Tensor { } } - /// Removes an extra rank on the tensor with dimension 1. + /// Creates a new tensor with the specified dimension removed if its size was one. /// /// ```rust /// # use candle::{Tensor, DType, Device, D}; @@ -1180,7 +1344,7 @@ impl Tensor { } } - /// Creates an extra rank on the tensor with dimension 1. + /// Creates a new tensor with a dimension of size one inserted at the specified position. /// /// ```rust /// # use candle::{Tensor, DType, Device, D}; @@ -1203,8 +1367,7 @@ impl Tensor { /// Stacks two or more tensors along a particular dimension. /// - /// All tensors must have the same rank, and the output has - /// 1 additional rank + /// All tensors must have the same rank, and the output has one additional rank /// /// ```rust /// # use candle::{Tensor, DType, Device};
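The renamed random constructors (`rand_uniform` → `rand`, `rand_normal` → `randn`) are documented above but ship without a runnable snippet. The following is a minimal standalone sketch, not part of the patch, assuming the signatures shown in the hunks above (shape, dtype, device, then the distribution parameters); since the output is random, it only checks shapes and the sampling range.

```rust
use candle::{DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // Uniform samples between -1 and 1, using the renamed constructor:
    // (shape, dtype, device, lo, up).
    let u = Tensor::rand((2, 3), DType::F32, &Device::Cpu, -1., 1.)?;
    let u = u.to_vec2::<f32>()?;
    assert_eq!(u.len(), 2);
    assert_eq!(u[0].len(), 3);
    assert!(u.iter().flatten().all(|v| (-1.0f32..1.0).contains(v)));

    // Normal samples with mean 0 and standard deviation 1: (shape, dtype, device, mean, std).
    let n = Tensor::randn((2, 3), DType::F32, &Device::Cpu, 0., 1.)?;
    assert_eq!(n.to_vec2::<f32>()?.len(), 2);
    Ok(())
}
```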
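`arange` and `arange_step` are documented above without an example. A small sketch, assuming the `(start, end, step, device)` argument order shown in the patch and the half-open `[start, end)` interval described in the doc comment:

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // Values from [0, 2) taken with a common difference of 0.5.
    let t = Tensor::arange_step(0f32, 2f32, 0.5, &Device::Cpu)?;
    assert_eq!(t.to_vec1::<f32>()?, &[0.0, 0.5, 1.0, 1.5]);
    Ok(())
}
```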
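The `matmul` documentation added above describes the batch-dimension contract (`b1, ..., bi, m, k` times `b1, ..., bi, k, n` gives `b1, ..., bi, m, n`) but has no example. Here is a sketch in the same doc-test style as the rest of the patch, using a plain 2x2 product so the expected values can be checked by hand:

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // (2, 2) x (2, 2) -> (2, 2); with leading batch dimensions this generalizes to
    // (b, m, k) x (b, k, n) -> (b, m, n) as described in the doc comment.
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    let b = Tensor::new(&[[5f32, 6.], [7., 8.]], &Device::Cpu)?;
    let c = a.matmul(&b)?;
    assert_eq!(c.to_vec2::<f32>()?, &[[19., 22.], [43., 50.]]);
    Ok(())
}
```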
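`where_cond` is likewise documented above without an example. The sketch below is an assumption-laden illustration rather than part of the patch: it assumes a `u32` mask tensor is accepted as the condition (as integer tensors are used elsewhere in candle at this point) and that all three tensors share the same shape, as the doc comment requires.

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // Take values from `on_true` where the mask is non-zero, from `on_false` elsewhere.
    let mask = Tensor::new(&[1u32, 0, 0, 1], &Device::Cpu)?;
    let on_true = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
    let on_false = Tensor::new(&[-1f32, -2., -3., -4.], &Device::Cpu)?;
    let t = mask.where_cond(&on_true, &on_false)?;
    assert_eq!(t.to_vec1::<f32>()?, &[1., -2., -3., 4.]);
    Ok(())
}
```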