Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Use the same default as pytorch for sum. (#164)
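For context: `torch.sum` reduces away the summed dimensions unless `keepdim=True` is passed, while candle's `sum` used to keep them around as size-1 dimensions. This commit renames the old behavior to `sum_keepdim` and makes `sum` squeeze the summed dimensions, matching the PyTorch default. A minimal sketch of the resulting API, assuming the `&[usize]` signatures introduced in this diff:

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    // Old default, now called `sum_keepdim`: the summed dim stays as size 1.
    assert_eq!(a.sum_keepdim(&[0])?.dims(), &[1, 2]);
    // New `sum`: the summed dim is squeezed away, as in PyTorch.
    assert_eq!(a.sum(&[0])?.dims(), &[2]);
    Ok(())
}
```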
@@ -572,7 +572,7 @@ impl Tensor {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
             // the operation.
             let exp = self.exp()?;
-            let sum_exp = exp.sum(&[dim])?;
+            let sum_exp = exp.sum_keepdim(&[dim])?;
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
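The softmax path above divides `exp` by its per-dimension sum via `broadcast_div`, and broadcasting only lines up if the reduced dimension is kept as size 1; hence this call site switches to `sum_keepdim` rather than the new squeezing `sum`. A sketch of the shapes involved (hypothetical values, not from the diff):

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let exp = x.exp()?;
    // Keeping dim 1 yields shape [2, 1], which broadcasts against [2, 3].
    let sum_exp = exp.sum_keepdim(&[1])?;
    assert_eq!(sum_exp.dims(), &[2, 1]);
    // A squeezed sum of shape [2] would not align with [2, 3] here.
    let softmax = exp.broadcast_div(&sum_exp)?;
    assert_eq!(softmax.dims(), &[2, 3]);
    Ok(())
}
```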
@@ -591,21 +591,21 @@ impl Tensor {
     /// Returns the sum of all elements in the input tensor. The sum is performed over all the
     /// input dimensions.
     ///
-    /// The resulting tensor as a shape that is similar to the shape of the input tensor, except
+    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
     /// that the number of elements for each dimension index in `sum_dims` is 1.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
     /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
-    /// let s = a.sum(&[0])?;
+    /// let s = a.sum_keepdim(&[0])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
-    /// let s = a.sum(&[1])?;
+    /// let s = a.sum_keepdim(&[1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
-    /// let s = a.sum(&[0, 1])?;
+    /// let s = a.sum_keepdim(&[0, 1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+    pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
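One consequence of the split worth calling out: when every dimension is summed, `sum_keepdim` returns a tensor whose shape is all ones, while the new `sum` collapses to a rank-0 scalar. A small sketch, reusing the tensor from the doctest above:

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    // Keepdim over both dims: the [[6.]] from the doctest, shape [1, 1].
    assert_eq!(a.sum_keepdim(&[0, 1])?.dims(), &[1, 1]);
    // The squeezing `sum` drops both dims, leaving a rank-0 tensor.
    assert_eq!(a.sum(&[0, 1])?.rank(), 0);
    Ok(())
}
```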
@@ -622,6 +622,32 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }

+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
+    /// kept.
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let sum = self.sum_keepdim(sum_dims)?;
+        match sum_dims {
+            [] => Ok(sum),
+            [i] => sum.squeeze(*i),
+            sum_dims => {
+                let dims = sum
+                    .dims()
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(dim_idx, &v)| {
+                        if sum_dims.contains(&dim_idx) {
+                            None
+                        } else {
+                            Some(v)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                sum.reshape(dims)
+            }
+        }
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
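The new `sum` wraps `sum_keepdim` and then removes the reduced axes: an empty `sum_dims` is returned as-is, a single dim goes through `squeeze`, and the multi-dim case rebuilds the target shape by filtering out the summed indices before a `reshape`. A sketch of that multi-dim branch on a rank-3 tensor (assuming `Tensor::zeros` with candle's `(shape, dtype, device)` signature):

```rust
use candle::{DType, Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // Summing a [2, 3, 4] tensor over dims 0 and 2 hits the filter_map
    // branch: dims 0 and 2 are dropped, leaving only dim 1.
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(t.sum_keepdim(&[0, 2])?.dims(), &[1, 3, 1]);
    assert_eq!(t.sum(&[0, 2])?.dims(), &[3]);
    Ok(())
}
```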
@@ -936,7 +962,7 @@ impl Tensor {
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
-        self.sum(&dims)?.reshape(())
+        self.sum_keepdim(&dims)?.reshape(())
     }

     fn flatten_<D1: Dim, D2: Dim>(
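`sum_all` keeps relying on the dimension-keeping reduction and then reshapes to the scalar shape `()`, so its result stays rank 0 after this change. A quick sketch (assuming `to_scalar`, which candle exposes for rank-0 tensors):

```rust
use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    // Reduced over every dim via sum_keepdim, then reshaped to shape `()`.
    let s = a.sum_all()?;
    assert_eq!(s.rank(), 0);
    assert_eq!(s.to_scalar::<f32>()?, 6.);
    Ok(())
}
```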