diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 1152dc3e..632ef116 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,7 @@ impl Shape {
 
 pub trait Dim {
     fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
 }
 
 impl Dim for usize {
@@ -200,6 +201,19 @@ impl Dim for usize {
             Ok(dim)
         }
     }
+
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let dim = *self;
+        if dim > shape.dims().len() {
+            Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim,
+                op,
+            })?
+        } else {
+            Ok(dim)
+        }
+    }
 }
 
 pub enum D {
@@ -220,6 +234,19 @@ impl Dim for D {
             }),
         }
     }
+
+    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+        let rank = shape.rank();
+        match self {
+            Self::Minus1 if rank >= 1 => Ok(rank),
+            Self::Minus2 if rank >= 2 => Ok(rank - 1),
+            _ => Err(Error::DimOutOfRange {
+                shape: shape.clone(),
+                dim: 42, // TODO: Have an adequate error
+                op,
+            }),
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f0ce18f9..cf8d01e6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,3 +1,4 @@
+// #![deny(missing_docs)]
 use crate::shape::Dim;
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
@@ -33,6 +34,19 @@ impl AsRef<Tensor> for Tensor {
 // Storages are also refcounted independently so that its possible to avoid
 // copying the storage for operations that only modify the shape or stride.
 #[derive(Clone)]
+/// The core struct for manipulating tensors.
+///
+/// ```rust
+/// use candle::{Tensor, DType, Device};
+///
+/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
+///
+/// let c = a.matmul(&b)?;
+/// # Ok::<(), candle::Error>(())
+/// ```
+///
+/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
 pub struct Tensor(Arc<Tensor_>);
 
 impl std::ops::Deref for Tensor {
@@ -126,20 +140,51 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
+    /// Creates a new tensor filled with ones.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::ones_impl(shape, dtype, device, false)
     }
 
-    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        // Maybe we should allocate some actual storage for vars rather than just using a
-        // broadcasted scalar?
-        Self::ones_impl(shape, dtype, device, true)
-    }
+    // Hidden for now: keeping this would force a potential `_var` variant on *every* function
+    // that creates a new tensor. Maybe something like `Tensor::ones(..).var()` would be easier
+    // to check.
+    // pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+    //     // Maybe we should allocate some actual storage for vars rather than just using a
+    //     // broadcasted scalar?
+    //     Self::ones_impl(shape, dtype, device, true)
+    // }
 
+    /// Creates a new tensor filled with ones, with the same shape, dtype, and device
+    /// as the original tensor.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = a.ones_like()?;
+    /// // b == a + 1
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn ones_like(&self) -> Result<Self> {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
+    /// Creates a new tensor filled with zeros.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -150,14 +195,33 @@ impl Tensor {
         from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
+    /// Creates a new tensor filled with zeros.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
+    /// // a == b
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
         Self::zeros_impl(shape, dtype, device, false)
     }
 
-    pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
-        Self::zeros_impl(shape, dtype, device, true)
-    }
+    // pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+    //     Self::zeros_impl(shape, dtype, device, true)
+    // }
 
+    /// Creates a new tensor filled with zeros, with the same shape, dtype, and device
+    /// as the original tensor.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = a.zeros_like()?;
+    /// // b is on CPU f32.
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn zeros_like(&self) -> Result<Self> {
         Tensor::zeros(self.shape(), self.dtype(), &self.device())
     }
@@ -187,7 +251,7 @@ impl Tensor {
         Self::new_impl(array, shape, device, true)
     }
 
-    pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -986,11 +1050,28 @@ impl Tensor {
         self.reshape(dims)
     }
 
+    /// Stacks two or more tensors along a particular dimension.
+    ///
+    /// All tensors must have the same rank, and the output has one
+    /// additional dimension.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::stack(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[2, 2, 3]);
+    ///
+    /// let c = Tensor::stack(&[&a, &b], 2)?;
+    /// assert_eq!(c.shape().dims(), &[2, 3, 2]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
         }
-        let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
+        let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
         let args = args
             .iter()
             .map(|t| t.as_ref().unsqueeze(dim))
@@ -998,6 +1079,23 @@ impl Tensor {
         Self::cat(&args, dim)
     }
 
+    /// Concatenates two or more tensors along a particular dimension.
+    ///
+    /// All tensors must have the same rank, and the output will have
+    /// the same rank.
+    ///
+    /// ```rust
+    /// # use candle::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[4, 3]);
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 6]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
     pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
@@ -1024,7 +1122,7 @@
         }
     }
 
-    pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
         if args.is_empty() {
             return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
         }
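Not part of the patch itself, but a quick sketch of what the `to_index_plus_one` change buys at the call site. It assumes the crate is consumed under the `candle` name used by the doctests above, and that `cat` keeps resolving its dimension with the stricter `to_index`, so an index equal to the rank is only accepted where a new axis is actually created:

```rust
use candle::{DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // `stack` resolves `dim` with `to_index_plus_one`, so for rank-2 inputs
    // the dimensions 0, 1 *and* 2 are all valid insertion points.
    let c = Tensor::stack(&[&a, &b], 2)?;
    assert_eq!(c.shape().dims(), &[2, 3, 2]);

    // `cat` keeps the output rank unchanged, so dimension 2 is out of range
    // for rank-2 inputs and should surface as a `DimOutOfRange` error.
    assert!(Tensor::cat(&[&a, &b], 2).is_err());
    Ok(())
}
```

Adding a second trait method keeps the relaxed `dim <= rank` check local to operations that insert a new axis, instead of loosening the bounds check in `to_index` for every caller.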