Simplify the parameters used by sum and sum_keepdim. (#165)

Laurent Mazare
2023-07-14 08:22:08 +01:00
committed by GitHub
parent 2bfa791336
commit a2f72edc0d
13 changed files with 179 additions and 98 deletions
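The heart of the change: `sum` and `sum_keepdim` no longer take a `&[usize]` slice but any type implementing the new `Dims` trait, so a single index, a tuple, or a `Vec` all work. A minimal before/after sketch of the call sites touched below (`t` stands for any tensor; illustrative, not part of the diff):

    // Before: dimensions were always passed as a slice.
    let s = t.sum_keepdim(&[0])?;
    let s = t.sum_keepdim(&[0, 1])?;
    // After: anything implementing `Dims` is accepted.
    let s = t.sum_keepdim(0)?;
    let s = t.sum_keepdim((0, 1))?;
    let s = t.sum(vec![0, 1])?;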

View File

@@ -7,9 +7,9 @@ use candle::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?;
let sum = t.sum_keepdim(&[0])?;
let sum = t.sum_keepdim(0)?;
println!("{sum}");
let sum = t.sum_keepdim(&[1])?;
let sum = t.sum_keepdim(1)?;
println!("{sum}");
Ok(())
}

View File

@@ -27,16 +27,16 @@ fn main() -> Result<()> {
let xys_cpu = cos_sin(n, &Device::Cpu)?;
let xys = cos_sin(n, &device)?;
println!("{xys_cpu:?} {xys:?}");
let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?;
let sum_keepdim_cpu = xys_cpu.sum_keepdim(1)?;
println!("{sum_keepdim_cpu}");
let sum_keepdim = xys.sum_keepdim(&[1])?;
let sum_keepdim = xys.sum_keepdim(1)?;
println!("{sum_keepdim}");
let start = std::time::Instant::now();
let n_iters = 100;
let mut v = 0f32;
for _i in 0..n_iters {
let sum_keepdim = xys.sum_keepdim(&[1])?;
let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?;
let sum_keepdim = xys.sum_keepdim(1)?;
let sum_keepdim = sum_keepdim.sum_keepdim(0)?;
let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?;
v += sum_keepdim;
}

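The benchmark loop above reduces over both dimensions with two `sum_keepdim` calls and then reshapes the 1x1 result to a scalar. With the new API the same total could be read out in a single call; a hypothetical shorthand, not part of the benchmark:

    let total = xys.sum_all()?.to_scalar::<f32>()?;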
View File

@@ -37,6 +37,13 @@ pub enum Error {
op: &'static str,
},
#[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
DuplicateDimIndex {
shape: Shape,
dims: Vec<usize>,
op: &'static str,
},
// === Shape Errors ===
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
UnexpectedNumberOfDims {

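The new `DuplicateDimIndex` variant is returned by the `Dims::to_indexes` validation introduced in the next file when the same dimension index is listed twice. A hedged sketch of a call that would trigger it (`a` is any rank-2 tensor; illustrative only):

    // Dimension 0 appears twice, so to_indexes bails with DuplicateDimIndex.
    assert!(a.sum_keepdim((0, 0)).is_err());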
View File

@@ -256,6 +256,86 @@ impl Dim for D {
}
}
pub trait Dims: Sized {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dims = self.to_indexes_internal(shape, op)?;
for (i, &dim) in dims.iter().enumerate() {
if dims[..i].contains(&dim) {
Err(Error::DuplicateDimIndex {
shape: shape.clone(),
dims: dims.clone(),
op,
})?
}
if dim >= shape.rank() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
dim: dim as i32,
op,
})?
}
}
Ok(dims)
}
}
impl Dims for Vec<usize> {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self)
}
}
impl<const N: usize> Dims for [usize; N] {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self.to_vec())
}
}
impl Dims for &[usize] {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self.to_vec())
}
}
impl Dims for () {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(vec![])
}
}
impl<D: Dim + Sized> Dims for D {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dim = self.to_index(shape, op)?;
Ok(vec![dim])
}
}
impl<D: Dim> Dims for (D,) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dim = self.0.to_index(shape, op)?;
Ok(vec![dim])
}
}
impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
Ok(vec![d0, d1])
}
}
impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
Ok(vec![d0, d1, d2])
}
}
#[cfg(test)]
mod tests {
use super::*;

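Taken together, these `Dims` impls let the reduction methods accept several equivalent argument shapes, all resolving to the same `Vec<usize>` of dimension indexes. A small sketch of equivalent calls, assuming a tensor `t` of rank 3 (illustrative, not from the diff):

    let a = t.sum_keepdim((0, 2))?;      // tuple of Dim values
    let b = t.sum_keepdim([0, 2])?;      // fixed-size array
    let c = t.sum_keepdim(vec![0, 2])?;  // Vec<usize>
    let d = t.sum_keepdim(&[0, 2][..])?; // &[usize] slice
    let e = t.sum_keepdim(())?;          // empty: no dimensions reduced

`to_indexes` then rejects out-of-range and duplicated indexes with `DimOutOfRange` and `DuplicateDimIndex` respectively.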
View File

@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::shape::Dim;
use crate::shape::{Dim, Dims};
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -572,7 +572,7 @@ impl Tensor {
// We do not have a cuda kernel for divide_by_sum_over_dim so split
// the operation.
let exp = self.exp()?;
let sum_exp = exp.sum_keepdim(&[dim])?;
let sum_exp = exp.sum_keepdim(dim)?;
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
@@ -588,28 +588,9 @@ impl Tensor {
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
/// that the number of elements for each dimension index in `sum_dims` is 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.sum_keepdim(&[0])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
/// let s = a.sum_keepdim(&[1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
/// let s = a.sum_keepdim(&[0, 1])?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result<Self> {
for &dim in sum_dims {
self.check_dim(dim, "sum")?;
}
let storage = self.storage().sum(self.layout(), sum_dims)?;
pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
let storage = self.storage().sum(self.layout(), &sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {
@@ -619,33 +600,58 @@ impl Tensor {
for &sum_dim in sum_dims.iter() {
dims[sum_dim] = 1
}
Ok(from_storage(storage, dims, op, false))
let sum = from_storage(storage, dims, op, false);
if keepdim {
Ok(sum)
} else {
match sum_dims.as_slice() {
[] => Ok(sum),
[i] => sum.squeeze(*i),
sum_dims => {
let dims = sum
.dims()
.iter()
.enumerate()
.filter_map(|(dim_idx, &v)| {
if sum_dims.contains(&dim_idx) {
None
} else {
Some(v)
}
})
.collect::<Vec<_>>();
sum.reshape(dims)
}
}
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions.
///
/// The resulting tensor has a shape that is similar to the shape of the input tensor, except
/// that the number of elements for each dimension index in `sum_dims` is 1.
///
/// ```rust
/// use candle::{Tensor, Device};
/// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
/// let s = a.sum_keepdim(0)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[2., 4.]]);
/// let s = a.sum_keepdim(1)?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[1.], [5.]]);
/// let s = a.sum_keepdim((0, 1))?;
/// assert_eq!(s.to_vec2::<f32>()?, &[[6.]]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn sum_keepdim<D: Dims>(&self, sum_dims: D) -> Result<Self> {
self.sum_impl(sum_dims, true)
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than
/// kept.
pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
let sum = self.sum_keepdim(sum_dims)?;
match sum_dims {
[] => Ok(sum),
[i] => sum.squeeze(*i),
sum_dims => {
let dims = sum
.dims()
.iter()
.enumerate()
.filter_map(|(dim_idx, &v)| {
if sum_dims.contains(&dim_idx) {
None
} else {
Some(v)
}
})
.collect::<Vec<_>>();
sum.reshape(dims)
}
}
pub fn sum<D: Dims>(&self, sum_dims: D) -> Result<Self> {
self.sum_impl(sum_dims, false)
}
/// Applies a 1D convolution over the input tensor.
@@ -962,7 +968,7 @@ impl Tensor {
/// ```
pub fn sum_all(&self) -> Result<Tensor> {
let dims: Vec<_> = (0..self.rank()).collect();
self.sum_keepdim(&dims)?.reshape(())
self.sum(dims)
}
fn flatten_<D1: Dim, D2: Dim>(

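With the reduction logic centralized in `sum_impl`, `sum_keepdim` keeps each reduced dimension as size 1 while `sum` squeezes (single dim) or reshapes (several dims) it away, and `sum_all` is now just a sum over every dimension. A hedged sketch of the resulting shapes, reusing the 2x2 example from the doc comment and assuming a function returning `candle::Result<()>`:

    let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?;
    assert_eq!(a.sum_keepdim(0)?.dims(), &[1, 2]);      // values [[2., 4.]]
    assert_eq!(a.sum(0)?.dims(), &[2]);                 // dim 0 squeezed away
    assert_eq!(a.sum_keepdim((0, 1))?.dims(), &[1, 1]);
    assert_eq!(a.sum((0, 1))?.rank(), 0);               // scalar tensor
    assert_eq!(a.sum_all()?.to_scalar::<f32>()?, 6.);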
View File

@@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> {
fn sum_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
let x = x.as_tensor();
let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?;
let y = (x.sqr()?.sum_keepdim(0)? * 2.)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [52.]);
@@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> {
assert_eq!(grad_x.to_vec1::<f32>()?, &[12., 4., 16.]);
// Same test as before but squeezing on the last dimension.
let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?;
let y = (x.sqr()?.sum_keepdim(0)? * 2.)?.squeeze(0)?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_scalar::<f32>()?, 52.);

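For reference, the values asserted in `sum_grad` follow directly from the chain rule; a quick hand check with x = [3, 1, 4]:

    y = 2 * (3^2 + 1^2 + 4^2) = 2 * 26 = 52
    dy/dx_i = 4 * x_i  =>  [12, 4, 16]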
View File

@@ -108,65 +108,53 @@ fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.sum_keepdim(&[2])?.to_vec3::<u32>()?,
tensor.sum_keepdim(2)?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]]
);
assert_eq!(
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
&[[[5, 2, 11], [9, 7, 17]]],
);
assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::<u32>()?, &[[[51]]],);
assert_eq!(tensor.sum_keepdim((0, 2, 1))?.to_vec3::<u32>()?, &[[[51]]],);
assert_eq!(
tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::<u32>()?,
tensor.t()?.sum_keepdim(1)?.t()?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]]
);
assert_eq!(
tensor.sum_keepdim(&[2, 1])?.to_vec3::<u32>()?,
tensor.sum_keepdim((2, 1))?.to_vec3::<u32>()?,
&[[[8 + 15]], [[10 + 18]]]
);
let data: Vec<u32> = (0..4000u32).collect();
let tensor = Tensor::new(data.as_slice(), device)?;
assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::<u32>()?, &[7998000]);
assert_eq!(tensor.sum_keepdim(0)?.to_vec1::<u32>()?, &[7998000]);
let tensor = tensor.reshape((2000, 2))?;
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
assert_eq!(
tensor
.sum_keepdim(&[0])?
.sum_keepdim(&[1])?
.to_vec2::<u32>()?,
tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
&[[7998000]]
);
assert_eq!(
tensor
.sum_keepdim(&[1])?
.sum_keepdim(&[0])?
.to_vec2::<u32>()?,
tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
&[[7998000]]
);
assert_eq!(
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
&[[3998000, 4000000]]
);
// Make the tensor non contiguous.
let tensor = tensor.t()?.contiguous()?.t()?;
assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::<u32>()?, &[[7998000]]);
assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::<u32>()?, &[[7998000]]);
assert_eq!(
tensor
.sum_keepdim(&[0])?
.sum_keepdim(&[1])?
.to_vec2::<u32>()?,
tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::<u32>()?,
&[[7998000]]
);
assert_eq!(
tensor
.sum_keepdim(&[1])?
.sum_keepdim(&[0])?
.to_vec2::<u32>()?,
tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::<u32>()?,
&[[7998000]]
);
assert_eq!(
tensor.sum_keepdim(&[0])?.to_vec2::<u32>()?,
tensor.sum_keepdim(0)?.to_vec2::<u32>()?,
&[[3998000, 4000000]]
);
@@ -174,33 +162,33 @@ fn sum(device: &Device) -> Result<()> {
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
for tensor in [t1, t2] {
assert_eq!(
tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::<u32>()?,
tensor.sum_keepdim((0, 1, 2))?.to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
tensor
.sum_keepdim(&[0])?
.sum_keepdim(&[2])?
.sum_keepdim(&[1])?
.sum_keepdim(0)?
.sum_keepdim(2)?
.sum_keepdim(1)?
.to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
tensor
.sum_keepdim(&[0])?
.sum_keepdim(&[1, 2])?
.sum_keepdim(0)?
.sum_keepdim((1, 2))?
.to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
tensor
.sum_keepdim(&[1])?
.sum_keepdim(&[0, 2])?
.sum_keepdim(1)?
.sum_keepdim((0, 2))?
.to_vec3::<u32>()?,
&[[[7998000]]]
);
assert_eq!(
tensor.sum_keepdim(&[0])?.to_vec3::<u32>()?,
tensor.sum_keepdim(0)?.to_vec3::<u32>()?,
&[[
[398000, 398200, 398400, 398600],
[398800, 399000, 399200, 399400],