mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add more documentation and examples. (#149)
* Add more documentation and examples. * More documentation and tests. * Document more tensor functions. * Again more examples and tests.
This commit is contained in:
@ -140,7 +140,7 @@ impl Tensor {
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with ones
|
||||
/// Creates a new tensor filled with ones.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
@ -159,8 +159,7 @@ impl Tensor {
|
||||
Self::ones_impl(shape, dtype, device, true)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with ones with same shape, dtype, and device
|
||||
/// as the other tensors
|
||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
@ -173,7 +172,7 @@ impl Tensor {
|
||||
Tensor::ones(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with zeros
|
||||
/// Creates a new tensor filled with zeros.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
@ -192,7 +191,7 @@ impl Tensor {
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with zeros
|
||||
/// Creates a new tensor filled with zeros.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
@ -209,8 +208,8 @@ impl Tensor {
|
||||
Self::zeros_impl(shape, dtype, device, true)
|
||||
}
|
||||
|
||||
/// Create a new tensors filled with ones with same shape, dtype, and device
|
||||
/// as the other tensors
|
||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other
|
||||
/// tensor.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
@ -223,7 +222,7 @@ impl Tensor {
|
||||
Tensor::zeros(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
|
||||
fn rand_uniform_impl<S: Into<Shape>>(
|
||||
fn rand_impl<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
@ -236,27 +235,28 @@ impl Tensor {
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
}
|
||||
|
||||
pub fn rand_uniform<S: Into<Shape>>(
|
||||
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
|
||||
pub fn rand<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_uniform_impl(s, dtype, device, lo, up, false)
|
||||
Self::rand_impl(s, dtype, device, lo, up, false)
|
||||
}
|
||||
|
||||
pub fn rand_uniform_var<S: Into<Shape>>(
|
||||
pub fn rand_var<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
lo: f64,
|
||||
up: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_uniform_impl(s, dtype, device, lo, up, true)
|
||||
Self::rand_impl(s, dtype, device, lo, up, true)
|
||||
}
|
||||
|
||||
fn rand_normal_impl<S: Into<Shape>>(
|
||||
fn randn_impl<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
@ -269,24 +269,26 @@ impl Tensor {
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
}
|
||||
|
||||
pub fn rand_normal<S: Into<Shape>>(
|
||||
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
/// specified `mean` and standard deviation `std`.
|
||||
pub fn randn<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_normal_impl(s, dtype, device, mean, std, false)
|
||||
Self::randn_impl(s, dtype, device, mean, std, false)
|
||||
}
|
||||
|
||||
pub fn rand_normal_var<S: Into<Shape>>(
|
||||
pub fn randn_var<S: Into<Shape>>(
|
||||
s: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
mean: f64,
|
||||
std: f64,
|
||||
) -> Result<Self> {
|
||||
Self::rand_normal_impl(s, dtype, device, mean, std, true)
|
||||
Self::randn_impl(s, dtype, device, mean, std, true)
|
||||
}
|
||||
|
||||
pub fn new_impl<A: crate::device::NdArray>(
|
||||
@ -304,17 +306,20 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor on the specified device using the content and shape of the input.
|
||||
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||
let shape = array.shape()?;
|
||||
Self::new_impl(array, shape, device, false)
|
||||
}
|
||||
|
||||
/// Creates a new tensor on the specified device using the content and shape of the input.
|
||||
/// This is similar to `new` but the resulting tensor is a variable.
|
||||
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
|
||||
let shape = array.shape()?;
|
||||
Self::new_impl(array, shape, device, true)
|
||||
}
|
||||
|
||||
/// Create a new 1D tensor from an iterator.
|
||||
/// Creates a new 1D tensor from an iterator.
|
||||
pub fn from_iter<D: crate::WithDType>(
|
||||
iter: impl IntoIterator<Item = D>,
|
||||
device: &Device,
|
||||
@ -324,13 +329,13 @@ impl Tensor {
|
||||
Self::from_vec_impl(data, len, device, false)
|
||||
}
|
||||
|
||||
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `1` from `start`.
|
||||
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
||||
Self::arange_step(start, end, D::one(), device)
|
||||
}
|
||||
|
||||
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||
/// difference `step` from `start`.
|
||||
pub fn arange_step<D: crate::WithDType>(
|
||||
start: D,
|
||||
@ -363,6 +368,9 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
/// If the device is cpu, no data copy is made.
|
||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
||||
data: Vec<D>,
|
||||
shape: S,
|
||||
@ -379,6 +387,8 @@ impl Tensor {
|
||||
Self::from_vec_impl(data, shape, device, true)
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values from the input slice. The number of elements
|
||||
/// in this vector must be the same as the number of elements defined by the shape.
|
||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||
array: &[D],
|
||||
shape: S,
|
||||
@ -478,6 +488,8 @@ impl Tensor {
|
||||
unary_op!(gelu, Gelu);
|
||||
unary_op!(relu, Relu);
|
||||
|
||||
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
|
||||
/// dimensions, an error is returned instead.
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
if self.rank() != 0 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
@ -496,6 +508,17 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
|
||||
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
|
||||
/// be performed.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
|
||||
/// let a = a.affine(4., -2.)?;
|
||||
/// assert_eq!(a.to_vec2::<f32>()?, &[[-2.0, 2.0], [6.0, 10.0]]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
let storage = self.storage.affine(self.layout(), mul, add)?;
|
||||
let op = if self.track_op() {
|
||||
@ -510,6 +533,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||
let storage = self.storage.elu(self.layout(), alpha)?;
|
||||
let op = if self.track_op() {
|
||||
@ -566,6 +590,21 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
|
||||
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
|
||||
/// let a = a.softmax(1)?;
|
||||
/// assert_eq!(
|
||||
/// a.to_vec2::<f32>()?,
|
||||
/// &[
|
||||
/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
|
||||
/// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
|
||||
/// ]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "softmax")?;
|
||||
// TODO: unify the two branches.
|
||||
@ -589,6 +628,23 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
|
||||
/// input dimensions.
|
||||
///
|
||||
/// The resulting tensor as a shape that is similar to the shape of the input tensor, except
|
||||
/// that the number of elements for each dimension index in `sum_dims` is 1.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
|
||||
/// let s = a.sum(&[0])?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
|
||||
/// let s = a.sum(&[1])?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
|
||||
/// let s = a.sum(&[0, 1])?;
|
||||
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
|
||||
for &dim in sum_dims {
|
||||
self.check_dim(dim, "sum")?;
|
||||
@ -606,6 +662,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
||||
let (b_size, c_in, l_in) = match *self.dims() {
|
||||
@ -654,6 +711,14 @@ impl Tensor {
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
|
||||
/// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
|
||||
///
|
||||
/// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
|
||||
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||
let a_dims = self.shape().dims();
|
||||
let b_dims = rhs.shape().dims();
|
||||
@ -698,6 +763,9 @@ impl Tensor {
|
||||
Ok(from_storage(storage, c_shape, op, false))
|
||||
}
|
||||
|
||||
/// Returns a tensor with the same shape as the input tensor, the values are taken from
|
||||
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
|
||||
/// input tensor is equal to zero.
|
||||
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
|
||||
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
|
||||
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
|
||||
@ -720,6 +788,25 @@ impl Tensor {
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Returns a tensor with the values from the `rhs` tensor at the index corresponding to the
|
||||
/// values hold in the `ids` tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `ids` - A tensor with dimensions `s` and with integer values between 0 and v (exclusive).
|
||||
/// * `rhs` - A tensor with dimensions `v, h`.
|
||||
///
|
||||
/// The resulting tensor has dimensions `s, h`. `s` is called the sequence length, `v` the
|
||||
/// vocabulary size, and `h` the hidden size.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let rhs = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let ids = Tensor::new(&[2u32, 1u32, 2u32], &Device::Cpu)?;
|
||||
/// let emb = Tensor::embedding(&ids, &rhs)?;
|
||||
/// assert_eq!(emb.to_vec2::<f32>()?, &[[4., 5.], [2., 3.], [4., 5.]]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
|
||||
if !rhs.is_contiguous() {
|
||||
return Err(Error::RequiresContiguous { op: "embedding" });
|
||||
@ -766,6 +853,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data contained in a 1D tensor as a vector of scalar values.
|
||||
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
|
||||
if self.rank() != 1 {
|
||||
return Err(Error::UnexpectedNumberOfDims {
|
||||
@ -788,6 +876,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
|
||||
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
||||
let (dim1, dim2) = self.shape().r2()?;
|
||||
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
||||
@ -807,6 +896,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data contained in a 3D tensor.
|
||||
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
|
||||
let (dim1, dim2, dim3) = self.shape().r3()?;
|
||||
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
||||
@ -830,27 +920,34 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// The dtype for the elements stored in the input tensor.
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.storage.dtype()
|
||||
}
|
||||
|
||||
/// The device on which the input tensor is located.
|
||||
pub fn device(&self) -> Device {
|
||||
self.storage.device()
|
||||
}
|
||||
|
||||
/// The tensor shape, i.e. dimension sizes on each axis.
|
||||
pub fn shape(&self) -> &Shape {
|
||||
self.layout().shape()
|
||||
}
|
||||
|
||||
/// The dimension size for this tensor on each axis.
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
self.shape().dims()
|
||||
}
|
||||
|
||||
/// The dimension size for a specified dimension index.
|
||||
pub fn dim<D: Dim>(&self, dim: D) -> Result<usize> {
|
||||
let dim = dim.to_index(self.shape(), "dim")?;
|
||||
Ok(self.dims()[dim])
|
||||
}
|
||||
|
||||
/// The layout of the input tensor, this stores both the shape of the tensor as well as the
|
||||
/// strides and the start offset to apply to the underlying storage.
|
||||
pub fn layout(&self) -> &Layout {
|
||||
&self.layout
|
||||
}
|
||||
@ -859,18 +956,23 @@ impl Tensor {
|
||||
self.layout.stride()
|
||||
}
|
||||
|
||||
/// The number of dimensions for this tensor, 0 for a scalar tensor, 1 for a 1D tensor, etc.
|
||||
pub fn rank(&self) -> usize {
|
||||
self.shape().rank()
|
||||
}
|
||||
|
||||
/// The number of elements stored in this tensor.
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.shape().elem_count()
|
||||
}
|
||||
|
||||
/// The unique identifier for this tensor.
|
||||
pub fn id(&self) -> TensorId {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Whether this tensor is a variable or not. A variable is a tensor for which gradient is
|
||||
/// tracked and on which backpropagation can be performed.
|
||||
pub fn is_variable(&self) -> bool {
|
||||
self.is_variable
|
||||
}
|
||||
@ -879,9 +981,19 @@ impl Tensor {
|
||||
&self.op
|
||||
}
|
||||
|
||||
/// Computes the sum of all the elements in this tensor and returns a tensor holding this
|
||||
/// scalar with zero dimensions.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.sum_all()?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 15.);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn sum_all(&self) -> Result<Tensor> {
|
||||
let dims: Vec<_> = (0..self.rank()).collect();
|
||||
self.sum(&dims)
|
||||
self.sum(&dims)?.reshape(())
|
||||
}
|
||||
|
||||
fn flatten_<D1: Dim, D2: Dim>(
|
||||
@ -914,22 +1026,47 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens the input tensor on the dimension indexes from `start_dim` to `end_dim` (both
|
||||
/// inclusive).
|
||||
pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
|
||||
self.flatten_(Some(start_dim), Some(end_dim))
|
||||
}
|
||||
|
||||
/// Flattens the input tensor on the dimension indexes from `0` to `end_dim` (inclusive).
|
||||
pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
|
||||
self.flatten_(None::<usize>, Some(end_dim))
|
||||
}
|
||||
|
||||
/// Flattens the input tensor on the dimension indexes from `start_dim` (inclusive) to the last
|
||||
/// dimension.
|
||||
pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
|
||||
self.flatten_(Some(start_dim), None::<usize>)
|
||||
}
|
||||
|
||||
/// Flattens the input tensor by reshaping it into a one dimension tensor.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.flatten_all()?;
|
||||
/// assert_eq!(tensor.to_vec1::<f32>()?, &[0., 1., 2., 3., 4., 5.]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn flatten_all(&self) -> Result<Tensor> {
|
||||
self.flatten_(None::<usize>, None::<usize>)
|
||||
}
|
||||
|
||||
/// Returns the sub-tensor fixing the index at `i` on the first dimension.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let t = tensor.get(0)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 1.]);
|
||||
/// let t = tensor.get(1)?;
|
||||
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn get(&self, i: usize) -> Result<Tensor> {
|
||||
let dims = self.dims();
|
||||
if dims.is_empty() {
|
||||
@ -941,6 +1078,14 @@ impl Tensor {
|
||||
|
||||
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
|
||||
/// input are swapped.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
|
||||
/// let tensor = tensor.t()?;
|
||||
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[0.0, 2.0, 4.0], [1.0, 3.0, 5.0]]);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn t(&self) -> Result<Tensor> {
|
||||
let rank = self.rank();
|
||||
if rank < 2 {
|
||||
@ -997,7 +1142,7 @@ impl Tensor {
|
||||
}
|
||||
|
||||
/// Returns a new tensor detached from the current graph, gradient are not propagated through
|
||||
/// this new node.
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1052,6 +1197,13 @@ impl Tensor {
|
||||
self.broadcast_as(dims)
|
||||
}
|
||||
|
||||
/// Broadcast the input tensor to the target shape. This returns an error if the input shape is
|
||||
/// not compatible with the target shape.
|
||||
///
|
||||
/// If the input shape is `i_1, i_2, ... i_k`, the target shape has to have `k` dimensions or
|
||||
/// more and shape `j_1, ..., j_l, t_1, t_2, ..., t_k`. The dimensions `j_1` to `j_l` can have
|
||||
/// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
|
||||
/// `i_a` is equal to 1, any value can be used.
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Broadcast(self.clone()))
|
||||
@ -1073,6 +1225,16 @@ impl Tensor {
|
||||
self.broadcast_as(shape)
|
||||
}
|
||||
|
||||
/// Casts the input tensor to the target `dtype`.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, Device};
|
||||
/// let tensor = Tensor::new(3.14159265358979f64, &Device::Cpu)?;
|
||||
/// assert_eq!(tensor.to_scalar::<f64>()?, 3.14159265358979);
|
||||
/// let tensor = tensor.to_dtype(candle::DType::F32)?;
|
||||
/// assert_eq!(tensor.to_scalar::<f32>()?, 3.1415927);
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
|
||||
if self.dtype() == dtype {
|
||||
Ok(self.clone())
|
||||
@ -1088,6 +1250,8 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a tensor that is in row major order. This is the same as the original tensor if it
|
||||
/// was already contiguous, otherwise a copy is triggered.
|
||||
pub fn contiguous(&self) -> Result<Tensor> {
|
||||
if self.is_contiguous() {
|
||||
Ok(self.clone())
|
||||
@ -1153,7 +1317,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes an extra rank on the tensor with dimension 1.
|
||||
/// Creates a new tensor with the specified dimension removed if its size was one.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device, D};
|
||||
@ -1180,7 +1344,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an extra rank on the tensor with dimension 1.
|
||||
/// Creates a new tensor with a dimension of size one inserted at the specified position.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device, D};
|
||||
@ -1203,8 +1367,7 @@ impl Tensor {
|
||||
|
||||
/// Stacks two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must have the same rank, and the output has
|
||||
/// 1 additional rank
|
||||
/// All tensors must have the same rank, and the output has one additional rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle::{Tensor, DType, Device};
|
||||
|
Reference in New Issue
Block a user