Mirror of https://github.com/huggingface/candle.git
Use the same default as pytorch for sum. (#164)
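In short, `sum` now squeezes the reduced dimensions by default (matching PyTorch's keepdim=False), while the previous keep-the-dims behaviour moves to `sum_keepdim`. A minimal sketch of the intended usage, not taken from the diff itself (shapes in the comments are assumptions):

use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // New default: the summed dimension is squeezed away, as in PyTorch.
    let s = t.sum(&[0])?; // expected shape: [3]
    println!("{s}");
    // The previous behaviour keeps the summed dimension with size 1.
    let s = t.sum_keepdim(&[0])?; // expected shape: [1, 3]
    println!("{s}");
    Ok(())
}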
@@ -7,9 +7,9 @@ use candle::{Device, Tensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
-    let sum = t.sum(&[0])?;
+    let sum = t.sum_keepdim(&[0])?;
     println!("{sum}");
-    let sum = t.sum(&[1])?;
+    let sum = t.sum_keepdim(&[1])?;
     println!("{sum}");
     Ok(())
 }
@@ -27,18 +27,18 @@ fn main() -> Result<()> {
     let xys_cpu = cos_sin(n, &Device::Cpu)?;
     let xys = cos_sin(n, &device)?;
     println!("{xys_cpu:?} {xys:?}");
-    let sum_cpu = xys_cpu.sum(&[1])?;
-    println!("{sum_cpu}");
-    let sum = xys.sum(&[1])?;
-    println!("{sum}");
+    let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
+    println!("{sum_keepdim_cpu}");
+    let sum_keepdim = xys.sum_keepdim(&[1])?;
+    println!("{sum_keepdim}");
     let start = std::time::Instant::now();
     let n_iters = 100;
     let mut v = 0f32;
     for _i in 0..n_iters {
-        let sum = xys.sum(&[1])?;
-        let sum = sum.sum(&[0])?;
-        let sum: f32 = sum.reshape(&[])?.to_scalar()?;
-        v += sum;
+        let sum_keepdim = xys.sum_keepdim(&[1])?;
+        let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
+        let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
+        v += sum_keepdim;
     }
     let elapsed = start.elapsed();
     if v > 0. {
@@ -195,11 +195,7 @@ impl Tensor {
                     }
                 }

-                let mut arg_grad = grad.sum(sum_dims.as_slice())?;
-                // sum_dims has increasing values.
-                for &dim in sum_dims.iter().rev() {
-                    arg_grad = arg_grad.squeeze(dim)?
-                }
+                let arg_grad = grad.sum(sum_dims.as_slice())?;
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.broadcast_add(&arg_grad)?
             }
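The removed loop squeezed each summed dimension by hand; with the new default, `sum` performs that squeeze itself, so `broadcast_add` still sees the shape it expects. A small equivalence sketch under that assumption, using a toy gradient tensor rather than anything from this file:

use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    // Toy "gradient" of shape [2, 1, 3]; pretend dim 1 was broadcast.
    let grad = Tensor::new(&[[[1f32, 2., 3.]], [[4., 5., 6.]]], &Device::Cpu)?;
    let sum_dims = [1usize];

    // Old style: keep the dims, then squeeze them manually.
    let mut a = grad.sum_keepdim(&sum_dims)?;
    for &dim in sum_dims.iter().rev() {
        a = a.squeeze(dim)?;
    }

    // New style: sum squeezes the reduced dims directly.
    let b = grad.sum(&sum_dims)?;

    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}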
@@ -572,7 +572,7 @@ impl Tensor {
             // We do not have a cuda kernel for divide_by_sum_over_dim so split
             // the operation.
             let exp = self.exp()?;
-            let sum_exp = exp.sum(&[dim])?;
+            let sum_exp = exp.sum_keepdim(&[dim])?;
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
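The softmax path keeps the reduced dimension on purpose: an [n, m] input gives per-row sums of shape [n, 1], which `broadcast_div` can stretch back across each row. A rough usage sketch (shapes in the comments are assumptions):

use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let x = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?; // [2, 3]
    let exp = x.exp()?; // [2, 3]
    let sum_exp = exp.sum_keepdim(&[1])?; // [2, 1]
    let softmax = exp.broadcast_div(&sum_exp)?; // [2, 3], each row sums to 1
    println!("{softmax}");
    Ok(())
}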
@@ -591,21 +591,21 @@ impl Tensor {
     /// Returns the sum of all elements in the input tensor. The sum is performed over all the
     /// input dimensions.
     ///
-    /// The resulting tensor as a shape that is similar to the shape of the input tensor, except
+    /// The resulting tensor has a shape that is similar to the shape of the input tensor, except
     /// that the number of elements for each dimension index in `sum_dims` is 1.
     ///
     /// ```rust
     /// use candle::{Tensor, Device};
     /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
-    /// let s = a.sum(&[0])?;
+    /// let s = a.sum_keepdim(&[0])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
-    /// let s = a.sum(&[1])?;
+    /// let s = a.sum_keepdim(&[1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
-    /// let s = a.sum(&[0, 1])?;
+    /// let s = a.sum_keepdim(&[0, 1])?;
     /// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
     /// # Ok::<(), candle::Error>(())
     /// ```
-    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+    pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
         for &dim in sum_dims {
             self.check_dim(dim, "sum")?;
         }
@@ -622,6 +622,32 @@ impl Tensor {
         Ok(from_storage(storage, dims, op, false))
     }

+    /// Returns the sum of all elements in the input tensor. The sum is performed over all the
+    /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
+    /// kept.
+    pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
+        let sum = self.sum_keepdim(sum_dims)?;
+        match sum_dims {
+            [] => Ok(sum),
+            [i] => sum.squeeze(*i),
+            sum_dims => {
+                let dims = sum
+                    .dims()
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(dim_idx, &v)| {
+                        if sum_dims.contains(&dim_idx) {
+                            None
+                        } else {
+                            Some(v)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                sum.reshape(dims)
+            }
+        }
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
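The multi-dimension arm above rebuilds the output shape by dropping every summed dimension from the keepdim result before reshaping. A standalone sketch of just that shape computation (the helper name is hypothetical, not part of the diff):

// Hypothetical helper mirroring the shape logic of the new `sum`:
// keep the size of every dimension that was not summed over.
fn squeezed_dims(keepdim_dims: &[usize], sum_dims: &[usize]) -> Vec<usize> {
    keepdim_dims
        .iter()
        .enumerate()
        .filter_map(|(dim_idx, &v)| {
            if sum_dims.contains(&dim_idx) {
                None
            } else {
                Some(v)
            }
        })
        .collect()
}

fn main() {
    // Summing dims 1 and 2 of a [2, 5, 4] tensor gives a [2, 1, 1] keepdim
    // result, which the new `sum` reshapes to [2].
    assert_eq!(squeezed_dims(&[2, 1, 1], &[1, 2]), vec![2]);
}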
@@ -936,7 +962,7 @@ impl Tensor {
     /// ```
     pub fn sum_all(&self) -> Result<Tensor> {
         let dims: Vec<_> = (0..self.rank()).collect();
-        self.sum(&dims)?.reshape(())
+        self.sum_keepdim(&dims)?.reshape(())
     }

     fn flatten_<D1: Dim, D2: Dim>(
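With this change `sum_all` still reduces over every dimension while keeping them, then reshapes the all-ones shape down to a scalar. A quick usage sketch, assuming only the API visible in this diff:

use candle::{Device, Tensor};

fn main() -> Result<(), candle::Error> {
    let t = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    let total = t.sum_all()?; // scalar-shaped tensor
    assert_eq!(total.to_scalar::<f32>()?, 10.);
    Ok(())
}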
@@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
 fn sum_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[3f32, 1., 4.], device)?;
     let x = x.as_tensor();
-    let y = (x.sqr()?.sum(&[0])? * 2.)?;
+    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [52.]);
@@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
     assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);

     // Same test as before but squeezing on the last dimension.
-    let y = (x.sqr()?.sum(&[0])? * 2.)?.squeeze(0)?;
+    let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_scalar::<f32>()?, 52.);
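For reference, the expected numbers in these asserts follow from y = 2 * sum(x_i^2), so dy/dx_i = 4 * x_i; with x = [3, 1, 4] that gives y = 52 and a gradient of [12, 4, 16]. A standalone check of that arithmetic (plain Rust, independent of the test code):

fn main() {
    let x = [3f32, 1., 4.];
    let y: f32 = 2. * x.iter().map(|v| v * v).sum::<f32>();
    let grad: Vec<f32> = x.iter().map(|v| 4. * v).collect();
    assert_eq!(y, 52.);
    assert_eq!(grad, vec![12., 4., 16.]);
}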
@@ -108,56 +108,99 @@ fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
     assert_eq!(
-        tensor.sum(&[2])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[0])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
         &[[[5, 2, 11], [9, 7, 17]]],
     );
-    assert_eq!(tensor.sum(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
+    assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
     assert_eq!(
-        tensor.t()?.sum(&[1])?.t()?.to_vec3::<u32>()?,
+        tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
     assert_eq!(
-        tensor.sum(&[2, 1])?.to_vec3::<u32>()?,
+        tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
         &[[[8 + 15]], [[10 + 18]]]
     );
     let data: Vec<u32> = (0..4000u32).collect();
     let tensor = Tensor::new(data.as_slice(), device)?;
-    assert_eq!(tensor.sum(&[0])?.to_vec1::<u32>()?, &[7998000]);
+    assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
     let tensor = tensor.reshape((2000, 2))?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );

     // Make the tensor non contiguous.
     let tensor = tensor.t()?.contiguous()?.t()?;
-    assert_eq!(tensor.sum(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.sum(&[1])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[1])?.sum(&[0])?.to_vec2::<u32>()?, &[[7998000]]);
-    assert_eq!(tensor.sum(&[0])?.to_vec2::<u32>()?, &[[3998000, 4000000]]);
+    assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[0])?
+            .sum_keepdim(&[1])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor
+            .sum_keepdim(&[1])?
+            .sum_keepdim(&[0])?
+            .to_vec2::<u32>()?,
+        &[[7998000]]
+    );
+    assert_eq!(
+        tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
+        &[[3998000, 4000000]]
+    );

     let t1 = tensor.reshape((200, 5, 4))?;
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
-        assert_eq!(tensor.sum(&[0, 1, 2])?.to_vec3::<u32>()?, &[[[7998000]]]);
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[2])?.sum(&[1])?.to_vec3::<u32>()?,
+            tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.sum(&[1, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[2])?
+                .sum_keepdim(&[1])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[1])?.sum(&[0, 2])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[0])?
+                .sum_keepdim(&[1, 2])?
+                .to_vec3::<u32>()?,
             &[[[7998000]]]
         );
         assert_eq!(
-            tensor.sum(&[0])?.to_vec3::<u32>()?,
+            tensor
+                .sum_keepdim(&[1])?
+                .sum_keepdim(&[0, 2])?
+                .to_vec3::<u32>()?,
+            &[[[7998000]]]
+        );
+        assert_eq!(
+            tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
             &[[
                 [398000, 398200, 398400, 398600],
                 [398800, 399000, 399200, 399400],