Add some docs + extend stack to work with extra final dimensions.

Nicolas Patry
2023-07-10 14:51:10 +02:00
parent 204618b7d3
commit 38ac50eeda
2 changed files with 136 additions and 11 deletions


@@ -185,6 +185,7 @@ impl Shape {
pub trait Dim {
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
}
impl Dim for usize {
@@ -200,6 +201,19 @@ impl Dim for usize {
Ok(dim)
}
}
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let dim = *self;
if dim > shape.dims().len() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
dim,
op,
})?
} else {
Ok(dim)
}
}
}
pub enum D {
@@ -220,6 +234,19 @@ impl Dim for D {
}),
}
}
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let rank = shape.rank();
match self {
Self::Minus1 if rank >= 1 => Ok(rank),
Self::Minus2 if rank >= 2 => Ok(rank - 1),
_ => Err(Error::DimOutOfRange {
shape: shape.clone(),
dim: 42, // TODO: Have an adequate error
op,
}),
}
}
}
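
For orientation (not part of the diff): a rough sketch of what the new to_index_plus_one resolver accepts, assuming Shape, D, and the Dim trait are importable as written below, which may differ in this revision. Unlike to_index, an index equal to the current rank is valid, since it names the extra trailing dimension that stack is about to insert.

// Hypothetical illustration, not in the commit.
use candle::shape::{Dim, Shape, D};

fn to_index_plus_one_sketch() -> candle::Result<()> {
    let shape = Shape::from((2, 3)); // rank 2
    // `to_index` rejects an index equal to the rank...
    assert!(2usize.to_index(&shape, "stack").is_err());
    // ...but `to_index_plus_one` accepts it: it names the would-be new last axis.
    assert_eq!(2usize.to_index_plus_one(&shape, "stack")?, 2);
    // For the named variants, Minus1 now resolves to that new last axis too.
    assert_eq!(D::Minus1.to_index_plus_one(&shape, "stack")?, 2);
    assert_eq!(D::Minus2.to_index_plus_one(&shape, "stack")?, 1);
    Ok(())
}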
#[cfg(test)]


@@ -1,3 +1,4 @@
// #![deny(missing_docs)]
use crate::shape::Dim;
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::Arc;
@@ -33,6 +34,19 @@ impl AsRef<Tensor> for Tensor {
// Storages are also refcounted independently so that it's possible to avoid
// copying the storage for operations that only modify the shape or stride.
#[derive(Clone)]
/// The core struct for manipulating tensors.
///
/// ```rust
/// use candle::{Tensor, DType, Device};
///
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
///
/// let c = a.matmul(&b)?;
/// # Ok::<(), candle::Error>(())
/// ```
///
/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
pub struct Tensor(Arc<Tensor_>);
impl std::ops::Deref for Tensor {
@@ -126,20 +140,51 @@ impl Tensor {
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
/// Creates a new tensor filled with ones
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
/// // a == b
/// # Ok::<(), candle::Error>(())
/// ```
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::ones_impl(shape, dtype, device, false)
}
pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
// Maybe we should allocate some actual storage for vars rather than just using a
// broadcasted scalar?
Self::ones_impl(shape, dtype, device, true)
}
// Hiding it for now: having these functions forces us to add a potential `_var` variant to
// *every* function that creates a new tensor. Maybe something more like `Tensor::ones(..).var()`
// might be easier to check.
// pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
// // Maybe we should allocate some actual storage for vars rather than just using a
// // broadcasted scalar?
// Self::ones_impl(shape, dtype, device, true)
// }
/// Creates a new tensor filled with ones with the same shape, dtype, and device
/// as the original tensor
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = a.ones_like()?;
/// // b == a + 1
/// # Ok::<(), candle::Error>(())
/// ```
pub fn ones_like(&self) -> Result<Self> {
Tensor::ones(self.shape(), self.dtype(), &self.device())
}
/// Creates a new tensor filled with zeros
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
/// // a == b
/// # Ok::<(), candle::Error>(())
/// ```
fn zeros_impl<S: Into<Shape>>(
shape: S,
dtype: DType,
@@ -150,14 +195,33 @@ impl Tensor {
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
}
/// Creates a new tensor filled with zeros
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
/// // a == b
/// # Ok::<(), candle::Error>(())
/// ```
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, false)
}
pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
Self::zeros_impl(shape, dtype, device, true)
}
// pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
// Self::zeros_impl(shape, dtype, device, true)
// }
/// Creates a new tensor filled with zeros with the same shape, dtype, and device
/// as the original tensor
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = a.zeros_like()?;
/// // b is on CPU f32.
/// # Ok::<(), candle::Error>(())
/// ```
pub fn zeros_like(&self) -> Result<Self> {
Tensor::zeros(self.shape(), self.dtype(), &self.device())
}
@@ -187,7 +251,7 @@ impl Tensor {
Self::new_impl(array, shape, device, true)
}
pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
device: &Device,
@@ -986,11 +1050,28 @@ impl Tensor {
self.reshape(dims)
}
/// Stacks two or more tensors along a particular dimension.
///
/// All tensors must have the same shape, and the output has
/// one additional dimension
///
/// ```rust
/// # use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::stack(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[2, 2, 3]);
///
/// let c = Tensor::stack(&[&a, &b], 2)?;
/// assert_eq!(c.shape().dims(), &[2, 3, 2]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
}
let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
let args = args
.iter()
.map(|t| t.as_ref().unsqueeze(dim))
@@ -998,6 +1079,23 @@ impl Tensor {
Self::cat(&args, dim)
}
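
As a rough illustration of the commit's intent (not part of the diff, and assuming D is re-exported at the crate root in this revision), stack now also accepts the position just past the last existing axis, either as a plain index or via D::Minus1:

// Hypothetical illustration, not in the commit.
use candle::{DType, Device, Tensor, D};

fn stack_along_new_last_axis() -> candle::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Dim 2 is one past the last axis of a rank-2 tensor; `to_index_plus_one`
    // accepts it, and `D::Minus1` resolves to the same new trailing axis.
    let c = Tensor::stack(&[&a, &b], D::Minus1)?;
    assert_eq!(c.shape().dims(), &[2, 3, 2]);
    Ok(())
}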
/// Concatenates two or more tensors along a particular dimension.
///
/// All tensors must be of the same rank, and the output will have
/// the same rank
///
/// ```rust
/// # use candle::{Tensor, DType, Device};
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
///
/// let c = Tensor::cat(&[&a, &b], 0)?;
/// assert_eq!(c.shape().dims(), &[4, 3]);
///
/// let c = Tensor::cat(&[&a, &b], 1)?;
/// assert_eq!(c.shape().dims(), &[2, 6]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
@@ -1024,7 +1122,7 @@ impl Tensor {
}
}
pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
if args.is_empty() {
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
}