mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adding some doc + Extended stack
to work with extra final dimensions.
This commit is contained in:
@ -185,6 +185,7 @@ impl Shape {
|
||||
|
||||
pub trait Dim {
|
||||
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
|
||||
}
|
||||
|
||||
impl Dim for usize {
|
||||
@ -200,6 +201,19 @@ impl Dim for usize {
|
||||
Ok(dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
|
||||
let dim = *self;
|
||||
if dim > shape.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
dim,
|
||||
op,
|
||||
})?
|
||||
} else {
|
||||
Ok(dim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum D {
|
||||
@ -220,6 +234,19 @@ impl Dim for D {
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
|
||||
let rank = shape.rank();
|
||||
match self {
|
||||
Self::Minus1 if rank >= 1 => Ok(rank),
|
||||
Self::Minus2 if rank >= 2 => Ok(rank - 1),
|
||||
_ => Err(Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
dim: 42, // TODO: Have an adequate error
|
||||
op,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1,3 +1,4 @@
|
||||
// #![deny(missing_docs)]
|
||||
use crate::shape::Dim;
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
@ -33,6 +34,19 @@ impl AsRef<Tensor> for Tensor {
|
||||
// Storages are also refcounted independently so that its possible to avoid
|
||||
// copying the storage for operations that only modify the shape or stride.
|
||||
#[derive(Clone)]
|
||||
/// The core struct for manipulating tensors.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
///
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = a.matmul(&b)?;
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
///
|
||||
/// Tensors are reference counted with [`Arc`] so cloning them is cheap.
|
||||
pub struct Tensor(Arc<Tensor_>);
|
||||
|
||||
impl std::ops::Deref for Tensor {
|
||||
@ -126,20 +140,51 @@ impl Tensor {
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with ones
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0], (2, 3), &Device::Cpu)?;
|
||||
/// // a == b
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
Self::ones_impl(shape, dtype, device, false)
|
||||
}
|
||||
|
||||
pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
// Maybe we should allocate some actual storage for vars rather than just using a
|
||||
// broadcasted scalar?
|
||||
Self::ones_impl(shape, dtype, device, true)
|
||||
}
|
||||
// Hiding it from now, having this functions forces us to have *every* function that creates
|
||||
// a new tensor potentially `_var` Maybe having something more like `Tensor::ones(..).var()`
|
||||
// might be easier to check.
|
||||
// pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
// // Maybe we should allocate some actual storage for vars rather than just using a
|
||||
// // broadcasted scalar?
|
||||
// Self::ones_impl(shape, dtype, device, true)
|
||||
// }
|
||||
|
||||
/// Create a new tensors filled with ones with same shape, dtype, and device
|
||||
/// as the other tensors
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = a.ones_like()?;
|
||||
/// // b == a + 1
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn ones_like(&self) -> Result<Self> {
|
||||
Tensor::ones(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with zeros
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
|
||||
/// // a == b
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
fn zeros_impl<S: Into<Shape>>(
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
@ -150,14 +195,33 @@ impl Tensor {
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with zeros
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0], (2, 3), &Device::Cpu)?;
|
||||
/// // a == b
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
Self::zeros_impl(shape, dtype, device, false)
|
||||
}
|
||||
|
||||
pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
Self::zeros_impl(shape, dtype, device, true)
|
||||
}
|
||||
// pub fn zeros_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
// Self::zeros_impl(shape, dtype, device, true)
|
||||
// }
|
||||
|
||||
/// Create a new tensors filled with ones with same shape, dtype, and device
|
||||
/// as the other tensors
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = a.zeros_like()?;
|
||||
/// // b is on CPU f32.
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn zeros_like(&self) -> Result<Self> {
|
||||
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
@ -187,7 +251,7 @@ impl Tensor {
|
||||
Self::new_impl(array, shape, device, true)
|
||||
}
|
||||
|
||||
pub fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||
fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
device: &Device,
|
||||
@ -986,11 +1050,28 @@ impl Tensor {
|
||||
self.reshape(dims)
|
||||
}
|
||||
|
||||
/// Stacks two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must have the same rank, and the output has
|
||||
/// 1 additional rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::stack(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 2, 3]);
|
||||
///
|
||||
/// let c = Tensor::stack(&[&a, &b], 2)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 3, 2]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn stack<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "stack" });
|
||||
}
|
||||
let dim = dim.to_index(args[0].as_ref().shape(), "stack")?;
|
||||
let dim = dim.to_index_plus_one(args[0].as_ref().shape(), "stack")?;
|
||||
let args = args
|
||||
.iter()
|
||||
.map(|t| t.as_ref().unsqueeze(dim))
|
||||
@ -998,6 +1079,23 @@ impl Tensor {
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||
@ -1024,7 +1122,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||
}
|
||||
|
Reference in New Issue
Block a user