Mirror of https://github.com/huggingface/candle.git
Simplify the parameters used by sum and sum_keepdim. (#165)
@@ -37,6 +37,13 @@ pub enum Error {
         op: &'static str,
     },
 
+    #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
+    DuplicateDimIndex {
+        shape: Shape,
+        dims: Vec<usize>,
+        op: &'static str,
+    },
+
     // === Shape Errors ===
     #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
     UnexpectedNumberOfDims {
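The new variant is what a caller sees when the same dimension index is passed twice to a reduction. A quick hedged sketch of the failure mode (the input tensor and the printed text are illustrative, not taken from the repository's tests):

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        // Summing over dim 0 twice is now reported as Error::DuplicateDimIndex
        // instead of silently reducing the same axis again.
        match a.sum_keepdim((0, 0)) {
            Err(e) => println!("{e}"), // roughly: sum: duplicate dim index [0, 0] for shape [2, 2]
            Ok(_) => unreachable!("duplicate dims should be rejected"),
        }
        Ok(())
    }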
@@ -256,6 +256,86 @@ impl Dim for D {
     }
 }
 
+pub trait Dims: Sized {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
+
+    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dims = self.to_indexes_internal(shape, op)?;
+        for (i, &dim) in dims.iter().enumerate() {
+            if dims[..i].contains(&dim) {
+                Err(Error::DuplicateDimIndex {
+                    shape: shape.clone(),
+                    dims: dims.clone(),
+                    op,
+                })?
+            }
+            if dim >= shape.rank() {
+                Err(Error::DimOutOfRange {
+                    shape: shape.clone(),
+                    dim: dim as i32,
+                    op,
+                })?
+            }
+        }
+        Ok(dims)
+    }
+}
+
+impl Dims for Vec<usize> {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self)
+    }
+}
+
+impl<const N: usize> Dims for [usize; N] {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self.to_vec())
+    }
+}
+
+impl Dims for &[usize] {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(self.to_vec())
+    }
+}
+
+impl Dims for () {
+    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
+        Ok(vec![])
+    }
+}
+
+impl<D: Dim + Sized> Dims for D {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dim = self.to_index(shape, op)?;
+        Ok(vec![dim])
+    }
+}
+
+impl<D: Dim> Dims for (D,) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let dim = self.0.to_index(shape, op)?;
+        Ok(vec![dim])
+    }
+}
+
+impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        Ok(vec![d0, d1])
+    }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2])
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
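Taken together, these impls let a bare index, a tuple, a fixed-size array, a Vec, a slice, or () all normalize to the same Vec<usize>, with out-of-range and duplicate indexes rejected during normalization. A minimal sketch of that behaviour (the candle::shape import path and the Shape::from constructor are assumptions about how the crate exposes these types):

    use candle::shape::{Dims, Shape};

    fn main() -> Result<(), candle::Error> {
        let shape = Shape::from((2, 3, 4));
        // Every accepted form ends up as the same Vec<usize>.
        assert_eq!([0usize, 2].to_indexes(&shape, "sum")?, vec![0, 2]);
        assert_eq!((0usize, 2usize).to_indexes(&shape, "sum")?, vec![0, 2]);
        assert_eq!(vec![0usize, 2].to_indexes(&shape, "sum")?, vec![0, 2]);
        // Validation happens once, in Dims::to_indexes.
        assert!([0usize, 5].to_indexes(&shape, "sum").is_err()); // out of range for rank 3
        assert!((1usize, 1usize).to_indexes(&shape, "sum").is_err()); // duplicate index
        Ok(())
    }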
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::shape::Dim;
+use crate::shape::{Dim, Dims};
 use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
 
@@ -572,7 +572,7 @@ impl Tensor {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
             // the operation.
            let exp = self.exp()?;
-            let sum_exp = exp.sum_keepdim(&[dim])?;
+            let sum_exp = exp.sum_keepdim(dim)?;
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
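For reference, the normalize-by-sum pattern this hunk touches now reads naturally with a bare dimension index. A sketch under the same assumptions as the doc examples elsewhere in this diff (illustrative input, values printed via to_vec2):

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        let x = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
        let exp = x.exp()?;
        // keepdim leaves the reduced axis as size 1, so broadcast_div lines up.
        let sum_exp = exp.sum_keepdim(1)?; // previously exp.sum_keepdim(&[1])?
        let softmax = exp.broadcast_div(&sum_exp)?;
        println!("{:?}", softmax.to_vec2::<f32>()?);
        Ok(())
    }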
@@ -588,28 +588,9 @@ impl Tensor {
         }
     }
 
-    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
-    /// input dimensions.
-    ///
-    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
-    /// that the number of elements for each dimension index in `sum_dims` is 1.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, Device};
-    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
-    /// let s = a.sum_keepdim(&[0])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
-    /// let s = a.sum_keepdim(&[1])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
-    /// let s = a.sum_keepdim(&[0, 1])?;
-    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
-        for &dim in sum_dims {
-            self.check_dim(dim, "sum")?;
-        }
-        let storage = self.storage().sum(self.layout(), sum_dims)?;
+    pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
+        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
+        let storage = self.storage().sum(self.layout(), &sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -619,33 +600,58 @@ impl Tensor {
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
-        Ok(from_storage(storage, dims, op, false))
+        let sum = from_storage(storage, dims, op, false);
+        if keepdim {
+            Ok(sum)
+        } else {
+            match sum_dims.as_slice() {
+                [] => Ok(sum),
+                [i] => sum.squeeze(*i),
+                sum_dims => {
+                    let dims = sum
+                        .dims()
+                        .iter()
+                        .enumerate()
+                        .filter_map(|(dim_idx, &v)| {
+                            if sum_dims.contains(&dim_idx) {
+                                None
+                            } else {
+                                Some(v)
+                            }
+                        })
+                        .collect::<Vec<_>>();
+                    sum.reshape(dims)
+                }
+            }
+        }
     }
 
+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions.
+    ///
+    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
+    /// that the number of elements for each dimension index in `sum_dims` is 1.
+    ///
+    /// ```rust
+    /// use candle::{Tensor, Device};
+    /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
+    /// let s = a.sum_keepdim(0)?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
+    /// let s = a.sum_keepdim(1)?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
+    /// let s = a.sum_keepdim((0, 1))?;
+    /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
+    /// # Ok::<(), candle::Error>(())
+    /// ```
+    pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        self.sum_impl(sum_dims, true)
+    }
+
     /// Returns the sum of all elements in the input tensor. The sum is performed over all the
     /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
     /// kept.
-    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
-        let sum = self.sum_keepdim(sum_dims)?;
-        match sum_dims {
-            [] => Ok(sum),
-            [i] => sum.squeeze(*i),
-            sum_dims => {
-                let dims = sum
-                    .dims()
-                    .iter()
-                    .enumerate()
-                    .filter_map(|(dim_idx, &v)| {
-                        if sum_dims.contains(&dim_idx) {
-                            None
-                        } else {
-                            Some(v)
-                        }
-                    })
-                    .collect::<Vec<_>>();
-                sum.reshape(dims)
-            }
-        }
+    pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
+        self.sum_impl(sum_dims, false)
     }
 
     /// Applies a 1D convolution over the input tensor.
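Net effect: both entry points now accept anything that implements Dims; sum_keepdim keeps each reduced axis with size 1 while sum squeezes it away. A small usage sketch built from the doc examples above (the shapes follow from the code in this hunk):

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        // keepdim = true: the reduced dims stay, with size 1.
        assert_eq!(a.sum_keepdim(0)?.dims(), &[1, 2]);
        assert_eq!(a.sum_keepdim((0, 1))?.dims(), &[1, 1]);
        // keepdim = false: the reduced dims are squeezed out.
        assert_eq!(a.sum(0)?.dims(), &[2]);
        assert!(a.sum((0, 1))?.dims().is_empty()); // rank-0 result
        Ok(())
    }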
@@ -962,7 +968,7 @@ impl Tensor {
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
-        self.sum_keepdim(&dims)?.reshape(())
+        self.sum(dims)
     }
 
     fn flatten_<D1: Dim, D2: Dim>(
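With this change, sum_all is just a reduction over every axis with keepdim = false, so it still ends up as a rank-0 tensor like the old reshape(()) did. A quick hedged check (the to_scalar accessor is an assumption about the API available here; the input is illustrative):

    use candle::{Device, Tensor};

    fn main() -> Result<(), candle::Error> {
        let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
        let total = a.sum_all()?;
        assert!(total.dims().is_empty()); // rank-0, same as the previous reshape(())
        assert_eq!(total.to_scalar::<f32>()?, 6.0);
        Ok(())
    }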