mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add a simpler way to specify the dim index for some ops.
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
use crate::shape::Dim;
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -362,9 +363,9 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||
/// ranges from `start` to `start + len`.
|
||||
pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
|
||||
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
|
||||
let dims = self.dims();
|
||||
self.check_dim(dim, "narrow")?;
|
||||
let dim = dim.to_index(self.shape(), "narrow")?;
|
||||
if start + len > dims[dim] {
|
||||
Err(Error::NarrowInvalidArgs {
|
||||
shape: self.shape().clone(),
|
||||
@ -392,8 +393,8 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax(&self, dim: usize) -> Result<Self> {
|
||||
self.check_dim(dim, "softmax")?;
|
||||
pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "softmax")?;
|
||||
// TODO: unify the two branches.
|
||||
if self.device().is_cuda() {
|
||||
// We do not have a cuda kernel for divide_by_sum_over_dim so split
|
||||
@ -692,14 +693,22 @@ impl Tensor {
|
||||
self.sum(&dims)
|
||||
}
|
||||
|
||||
pub fn flatten(&self, start_dim: Option<usize>, end_dim: Option<usize>) -> Result<Tensor> {
|
||||
fn flatten_<D1: Dim, D2: Dim>(
|
||||
&self,
|
||||
start_dim: Option<D1>,
|
||||
end_dim: Option<D2>,
|
||||
) -> Result<Tensor> {
|
||||
if self.rank() == 0 {
|
||||
self.reshape(1)
|
||||
} else {
|
||||
let start_dim = start_dim.unwrap_or(0);
|
||||
let end_dim = end_dim.unwrap_or_else(|| self.rank() - 1);
|
||||
self.check_dim(start_dim, "flatten")?;
|
||||
self.check_dim(end_dim, "flatten")?;
|
||||
let start_dim = match start_dim {
|
||||
None => 0,
|
||||
Some(dim) => dim.to_index(self.shape(), "flatten")?,
|
||||
};
|
||||
let end_dim = match end_dim {
|
||||
None => self.rank() - 1,
|
||||
Some(dim) => dim.to_index(self.shape(), "flatten")?,
|
||||
};
|
||||
if start_dim < end_dim {
|
||||
let dims = self.dims();
|
||||
let mut dst_dims = dims[..start_dim].to_vec();
|
||||
@ -714,8 +723,20 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flatten<D1: Dim, D2: Dim>(&self, start_dim: D1, end_dim: D2) -> Result<Tensor> {
|
||||
self.flatten_(Some(start_dim), Some(end_dim))
|
||||
}
|
||||
|
||||
pub fn flatten_to<D: Dim>(&self, end_dim: D) -> Result<Tensor> {
|
||||
self.flatten_(None::<usize>, Some(end_dim))
|
||||
}
|
||||
|
||||
pub fn flatten_from<D: Dim>(&self, start_dim: D) -> Result<Tensor> {
|
||||
self.flatten_(Some(start_dim), None::<usize>)
|
||||
}
|
||||
|
||||
pub fn flatten_all(&self) -> Result<Tensor> {
|
||||
self.flatten(None, None)
|
||||
self.flatten_(None::<usize>, None::<usize>)
|
||||
}
|
||||
|
||||
pub fn get(&self, i: usize) -> Result<Tensor> {
|
||||
@ -743,9 +764,9 @@ impl Tensor {
|
||||
|
||||
/// Returns a tensor that is a transposed version of the input, the given dimensions are
|
||||
/// swapped.
|
||||
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
|
||||
self.check_dim(dim1, "transpose")?;
|
||||
self.check_dim(dim2, "transpose")?;
|
||||
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
|
||||
let dim1 = dim1.to_index(self.shape(), "transpose")?;
|
||||
let dim2 = dim2.to_index(self.shape(), "transpose")?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Transpose(self.clone(), dim1, dim2))
|
||||
} else {
|
||||
@ -929,23 +950,23 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn squeeze(&self, index: usize) -> Result<Self> {
|
||||
pub fn squeeze<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
// The PyTorch semantics are to return the same tensor if the target dimension
|
||||
// does not have a size of 1.
|
||||
let dims = self.dims();
|
||||
self.check_dim(index, "squeeze")?;
|
||||
if dims[index] == 1 {
|
||||
let dim = dim.to_index(self.shape(), "squeeze")?;
|
||||
if dims[dim] == 1 {
|
||||
let mut dims = dims.to_vec();
|
||||
dims.remove(index);
|
||||
dims.remove(dim);
|
||||
self.reshape(dims)
|
||||
} else {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unsqueeze(&self, index: usize) -> Result<Self> {
|
||||
pub fn unsqueeze(&self, dim: usize) -> Result<Self> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims.insert(index, 1);
|
||||
dims.insert(dim, 1);
|
||||
self.reshape(dims)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user