Use the same default as pytorch for sum. (#164)

Laurent Mazare
2023-07-13 21:32:32 +01:00
committed by GitHub
parent 57be3638d8
commit 2bfa791336
13 changed files with 123 additions and 56 deletions


@@ -572,7 +572,7 @@ impl Tensor {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
// the operation.
let exp = self.exp()?;
let sum_exp = exp.sum(&[dim])?;
let sum_exp = exp.sum_keepdim(&[dim])?;
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
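Side note on the softmax hunk above: the keep-dim variant is needed there because `broadcast_div` expects the per-row sums to stay rank-2 (shape `[N, 1]`) so they broadcast against the `[N, M]` exponentials. A minimal sketch of that composition, assuming a 2D input and only the calls that appear in this diff (`exp`, `sum_keepdim`, `broadcast_div`, `to_vec2`):

```rust
use candle::{Device, Error, Tensor};

// Softmax over the last dimension of a 2D tensor, written the same way as
// the branch above: exponentiate, sum with the dimension kept, then divide
// with broadcasting.
fn softmax_dim1(xs: &Tensor) -> Result<Tensor, Error> {
    let exp = xs.exp()?;
    // Keeping dim 1 gives shape [N, 1], which broadcasts against [N, M].
    let sum_exp = exp.sum_keepdim(&[1])?;
    exp.broadcast_div(&sum_exp)
}

fn main() -> Result<(), Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    let sm = softmax_dim1(&a)?;
    // Each row of the result should sum to 1.
    println!("{:?}", sm.to_vec2::<f32>()?);
    Ok(())
}
```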
@@ -591,21 +591,21 @@ impl Tensor {
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///
/// The resulting tensor as a shape that is similar to the shape of the input tensor, except
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
/// that the number of elements for each dimension index in `sum_dims` is 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.sum(&[0])?;
/// let s = a.sum_keepdim(&[0])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
/// let s = a.sum(&[1])?;
/// let s = a.sum_keepdim(&[1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
/// let s = a.sum(&[0, 1])?;
/// let s = a.sum_keepdim(&[0, 1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
for &dim in sum_dims {
self.check_dim(dim, "sum")?;
}
@@ -622,6 +622,32 @@ impl Tensor {
Ok(from_storage(storage, dims, op, false))
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions and, compared to `sum_keepdim`, the summed dimensions are squeezed
/// rather than kept.
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
let sum = self.sum_keepdim(sum_dims)?;
match sum_dims {
[] => Ok(sum),
[i] => sum.squeeze(*i),
sum_dims => {
let dims = sum
.dims()
.iter()
.enumerate()
.filter_map(|(dim_idx, &v)| {
if sum_dims.contains(&dim_idx) {
None
} else {
Some(v)
}
})
.collect::<Vec<_>>();
sum.reshape(dims)
}
}
}
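The user-visible difference between the two reductions is the output shape: `sum_keepdim` keeps each reduced axis with size 1, while the new `sum` squeezes those axes away, matching PyTorch's default as the commit title says. A small sketch based on the doc example above; `dims` and `to_vec1` are assumed to be available alongside the calls shown in the diff:

```rust
use candle::{Device, Error, Tensor};

fn main() -> Result<(), Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;

    // Keep-dim reduction: the summed axis stays with size 1, shape [1, 2].
    let kept = a.sum_keepdim(&[0])?;
    assert_eq!(kept.dims(), &[1, 2]);
    assert_eq!(kept.to_vec2::<f32>()?, &[[2., 4.]]);

    // New default: the summed axis is squeezed away, shape [2].
    let squeezed = a.sum(&[0])?;
    assert_eq!(squeezed.dims(), &[2]);
    assert_eq!(squeezed.to_vec1::<f32>()?, &[2., 4.]);
    Ok(())
}
```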
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
@@ -936,7 +962,7 @@ impl Tensor {
/// ```
pub fn sum_all(&self) -> Result<Tensor> {
let dims: Vec<_> = (0..self.rank()).collect();
self.sum(&dims)?.reshape(())
self.sum_keepdim(&dims)?.reshape(())
}
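`sum_all` composes the two pieces above: a keep-dim sum over every dimension followed by `reshape(())` to drop down to a rank-0 (scalar) tensor. A quick sketch of the shape it produces; `to_scalar` is an assumed accessor here, not part of this diff:

```rust
use candle::{Device, Error, Tensor};

fn main() -> Result<(), Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    // sum_keepdim over all dims yields shape [1, 1]; reshape(()) makes it rank 0.
    let total = a.sum_all()?;
    assert!(total.dims().is_empty());
    // Assumed: to_scalar extracts the single value from a rank-0 tensor.
    assert_eq!(total.to_scalar::<f32>()?, 6.);
    Ok(())
}
```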
fn flatten_<D1: Dim, D2: Dim>(