diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 37a66cb5..645420e5 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -7,9 +7,9 @@ use candle::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; let t = Tensor::new(&[[1f32, 2., 3., 4.2]], &device)?; - let sum = t.sum_keepdim(&[0])?; + let sum = t.sum_keepdim(0)?; println!("{sum}"); - let sum = t.sum_keepdim(&[1])?; + let sum = t.sum_keepdim(1)?; println!("{sum}"); Ok(()) } diff --git a/candle-core/examples/cuda_sum_benchmark.rs b/candle-core/examples/cuda_sum_benchmark.rs index 86a1691d..bb4dc8be 100644 --- a/candle-core/examples/cuda_sum_benchmark.rs +++ b/candle-core/examples/cuda_sum_benchmark.rs @@ -27,16 +27,16 @@ fn main() -> Result<()> { let xys_cpu = cos_sin(n, &Device::Cpu)?; let xys = cos_sin(n, &device)?; println!("{xys_cpu:?} {xys:?}"); - let sum_keepdim_cpu = xys_cpu.sum_keepdim(&[1])?; + let sum_keepdim_cpu = xys_cpu.sum_keepdim(1)?; println!("{sum_keepdim_cpu}"); - let sum_keepdim = xys.sum_keepdim(&[1])?; + let sum_keepdim = xys.sum_keepdim(1)?; println!("{sum_keepdim}"); let start = std::time::Instant::now(); let n_iters = 100; let mut v = 0f32; for _i in 0..n_iters { - let sum_keepdim = xys.sum_keepdim(&[1])?; - let sum_keepdim = sum_keepdim.sum_keepdim(&[0])?; + let sum_keepdim = xys.sum_keepdim(1)?; + let sum_keepdim = sum_keepdim.sum_keepdim(0)?; let sum_keepdim: f32 = sum_keepdim.reshape(&[])?.to_scalar()?; v += sum_keepdim; } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 8746c2fe..d2648e66 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -37,6 +37,13 @@ pub enum Error { op: &'static str, }, + #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")] + DuplicateDimIndex { + shape: Shape, + dims: Vec, + op: &'static str, + }, + // === Shape Errors === #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")] UnexpectedNumberOfDims { diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index ad53f585..a267c068 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -256,6 +256,86 @@ impl Dim for D { } } +pub trait Dims: Sized { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result>; + + fn to_indexes(self, shape: &Shape, op: &'static str) -> Result> { + let dims = self.to_indexes_internal(shape, op)?; + for (i, &dim) in dims.iter().enumerate() { + if dims[..i].contains(&dim) { + Err(Error::DuplicateDimIndex { + shape: shape.clone(), + dims: dims.clone(), + op, + })? + } + if dim >= shape.rank() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: dim as i32, + op, + })? + } + } + Ok(dims) + } +} + +impl Dims for Vec { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self) + } +} + +impl Dims for [usize; N] { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self.to_vec()) + } +} + +impl Dims for &[usize] { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(self.to_vec()) + } +} + +impl Dims for () { + fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result> { + Ok(vec![]) + } +} + +impl Dims for D { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let dim = self.to_index(shape, op)?; + Ok(vec![dim]) + } +} + +impl Dims for (D,) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let dim = self.0.to_index(shape, op)?; + Ok(vec![dim]) + } +} + +impl Dims for (D1, D2) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + Ok(vec![d0, d1]) + } +} + +impl Dims for (D1, D2, D3) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + Ok(vec![d0, d1, d2]) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index af3675cc..c8353a70 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::shape::Dim; +use crate::shape::{Dim, Dims}; use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::{Arc, RwLock}; @@ -572,7 +572,7 @@ impl Tensor { // We do not have a cuda kernel for divide_by_sum_over_dim so split // the operation. let exp = self.exp()?; - let sum_exp = exp.sum_keepdim(&[dim])?; + let sum_exp = exp.sum_keepdim(dim)?; exp.broadcast_div(&sum_exp) } else { let shape = self.shape(); @@ -588,28 +588,9 @@ impl Tensor { } } - /// Returns the sum of all elements in the input tensor. The sum is performed over all the - /// input dimensions. - /// - /// The resulting tensor has a shape that is similar to the shape of the input tensor, except - /// that the number of elements for each dimension index in `sum_dims` is 1. - /// - /// ```rust - /// use candle::{Tensor, Device}; - /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; - /// let s = a.sum_keepdim(&[0])?; - /// assert_eq!(s.to_vec2::()?, &[[2., 4.]]); - /// let s = a.sum_keepdim(&[1])?; - /// assert_eq!(s.to_vec2::()?, &[[1.], [5.]]); - /// let s = a.sum_keepdim(&[0, 1])?; - /// assert_eq!(s.to_vec2::()?, &[[6.]]); - /// # Ok::<(), candle::Error>(()) - /// ``` - pub fn sum_keepdim(&self, sum_dims: &[usize]) -> Result { - for &dim in sum_dims { - self.check_dim(dim, "sum")?; - } - let storage = self.storage().sum(self.layout(), sum_dims)?; + pub fn sum_impl(&self, sum_dims: D, keepdim: bool) -> Result { + let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?; + let storage = self.storage().sum(self.layout(), &sum_dims)?; let op = if self.track_op() { Some(Op::Sum(self.clone(), sum_dims.to_vec())) } else { @@ -619,33 +600,58 @@ impl Tensor { for &sum_dim in sum_dims.iter() { dims[sum_dim] = 1 } - Ok(from_storage(storage, dims, op, false)) + let sum = from_storage(storage, dims, op, false); + if keepdim { + Ok(sum) + } else { + match sum_dims.as_slice() { + [] => Ok(sum), + [i] => sum.squeeze(*i), + sum_dims => { + let dims = sum + .dims() + .iter() + .enumerate() + .filter_map(|(dim_idx, &v)| { + if sum_dims.contains(&dim_idx) { + None + } else { + Some(v) + } + }) + .collect::>(); + sum.reshape(dims) + } + } + } + } + + /// Returns the sum of all elements in the input tensor. The sum is performed over all the + /// input dimensions. + /// + /// The resulting tensor has a shape that is similar to the shape of the input tensor, except + /// that the number of elements for each dimension index in `sum_dims` is 1. + /// + /// ```rust + /// use candle::{Tensor, Device}; + /// let a = Tensor::new(&[[0f32, 1.], [2., 3.]], &Device::Cpu)?; + /// let s = a.sum_keepdim(0)?; + /// assert_eq!(s.to_vec2::()?, &[[2., 4.]]); + /// let s = a.sum_keepdim(1)?; + /// assert_eq!(s.to_vec2::()?, &[[1.], [5.]]); + /// let s = a.sum_keepdim((0, 1))?; + /// assert_eq!(s.to_vec2::()?, &[[6.]]); + /// # Ok::<(), candle::Error>(()) + /// ``` + pub fn sum_keepdim(&self, sum_dims: D) -> Result { + self.sum_impl(sum_dims, true) } /// Returns the sum of all elements in the input tensor. The sum is performed over all the /// input dimensions and compared to `sum_keepdim` these dimensions are squeezed rather than /// kept. - pub fn sum(&self, sum_dims: &[usize]) -> Result { - let sum = self.sum_keepdim(sum_dims)?; - match sum_dims { - [] => Ok(sum), - [i] => sum.squeeze(*i), - sum_dims => { - let dims = sum - .dims() - .iter() - .enumerate() - .filter_map(|(dim_idx, &v)| { - if sum_dims.contains(&dim_idx) { - None - } else { - Some(v) - } - }) - .collect::>(); - sum.reshape(dims) - } - } + pub fn sum(&self, sum_dims: D) -> Result { + self.sum_impl(sum_dims, false) } /// Applies a 1D convolution over the input tensor. @@ -962,7 +968,7 @@ impl Tensor { /// ``` pub fn sum_all(&self) -> Result { let dims: Vec<_> = (0..self.rank()).collect(); - self.sum_keepdim(&dims)?.reshape(()) + self.sum(dims) } fn flatten_( diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index 241efc69..a7f6725d 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -19,7 +19,7 @@ fn simple_grad(device: &Device) -> Result<()> { fn sum_grad(device: &Device) -> Result<()> { let x = Var::new(&[3f32, 1., 4.], device)?; let x = x.as_tensor(); - let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?; + let y = (x.sqr()?.sum_keepdim(0)? * 2.)?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!(y.to_vec1::()?, [52.]); @@ -27,7 +27,7 @@ fn sum_grad(device: &Device) -> Result<()> { assert_eq!(grad_x.to_vec1::()?, &[12., 4., 16.]); // Same test as before but squeezing on the last dimension. - let y = (x.sqr()?.sum_keepdim(&[0])? * 2.)?.squeeze(0)?; + let y = (x.sqr()?.sum_keepdim(0)? * 2.)?.squeeze(0)?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!(y.to_scalar::()?, 52.); diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index b9e8a982..7b73cd7a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -108,65 +108,53 @@ fn sum(device: &Device) -> Result<()> { let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]]; let tensor = Tensor::new(data, device)?; assert_eq!( - tensor.sum_keepdim(&[2])?.to_vec3::()?, + tensor.sum_keepdim(2)?.to_vec3::()?, &[[[8], [15]], [[10], [18]]] ); assert_eq!( - tensor.sum_keepdim(&[0])?.to_vec3::()?, + tensor.sum_keepdim(0)?.to_vec3::()?, &[[[5, 2, 11], [9, 7, 17]]], ); - assert_eq!(tensor.sum_keepdim(&[0, 2, 1])?.to_vec3::()?, &[[[51]]],); + assert_eq!(tensor.sum_keepdim((0, 2, 1))?.to_vec3::()?, &[[[51]]],); assert_eq!( - tensor.t()?.sum_keepdim(&[1])?.t()?.to_vec3::()?, + tensor.t()?.sum_keepdim(1)?.t()?.to_vec3::()?, &[[[8], [15]], [[10], [18]]] ); assert_eq!( - tensor.sum_keepdim(&[2, 1])?.to_vec3::()?, + tensor.sum_keepdim((2, 1))?.to_vec3::()?, &[[[8 + 15]], [[10 + 18]]] ); let data: Vec = (0..4000u32).collect(); let tensor = Tensor::new(data.as_slice(), device)?; - assert_eq!(tensor.sum_keepdim(&[0])?.to_vec1::()?, &[7998000]); + assert_eq!(tensor.sum_keepdim(0)?.to_vec1::()?, &[7998000]); let tensor = tensor.reshape((2000, 2))?; - assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::()?, &[[7998000]]); assert_eq!( - tensor - .sum_keepdim(&[0])? - .sum_keepdim(&[1])? - .to_vec2::()?, + tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::()?, &[[7998000]] ); assert_eq!( - tensor - .sum_keepdim(&[1])? - .sum_keepdim(&[0])? - .to_vec2::()?, + tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::()?, &[[7998000]] ); assert_eq!( - tensor.sum_keepdim(&[0])?.to_vec2::()?, + tensor.sum_keepdim(0)?.to_vec2::()?, &[[3998000, 4000000]] ); // Make the tensor non contiguous. let tensor = tensor.t()?.contiguous()?.t()?; - assert_eq!(tensor.sum_keepdim(&[0, 1])?.to_vec2::()?, &[[7998000]]); + assert_eq!(tensor.sum_keepdim((0, 1))?.to_vec2::()?, &[[7998000]]); assert_eq!( - tensor - .sum_keepdim(&[0])? - .sum_keepdim(&[1])? - .to_vec2::()?, + tensor.sum_keepdim(0)?.sum_keepdim(1)?.to_vec2::()?, &[[7998000]] ); assert_eq!( - tensor - .sum_keepdim(&[1])? - .sum_keepdim(&[0])? - .to_vec2::()?, + tensor.sum_keepdim(1)?.sum_keepdim(0)?.to_vec2::()?, &[[7998000]] ); assert_eq!( - tensor.sum_keepdim(&[0])?.to_vec2::()?, + tensor.sum_keepdim(0)?.to_vec2::()?, &[[3998000, 4000000]] ); @@ -174,33 +162,33 @@ fn sum(device: &Device) -> Result<()> { let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?; for tensor in [t1, t2] { assert_eq!( - tensor.sum_keepdim(&[0, 1, 2])?.to_vec3::()?, + tensor.sum_keepdim((0, 1, 2))?.to_vec3::()?, &[[[7998000]]] ); assert_eq!( tensor - .sum_keepdim(&[0])? - .sum_keepdim(&[2])? - .sum_keepdim(&[1])? + .sum_keepdim(0)? + .sum_keepdim(2)? + .sum_keepdim(1)? .to_vec3::()?, &[[[7998000]]] ); assert_eq!( tensor - .sum_keepdim(&[0])? - .sum_keepdim(&[1, 2])? + .sum_keepdim(0)? + .sum_keepdim((1, 2))? .to_vec3::()?, &[[[7998000]]] ); assert_eq!( tensor - .sum_keepdim(&[1])? - .sum_keepdim(&[0, 2])? + .sum_keepdim(1)? + .sum_keepdim((0, 2))? .to_vec3::()?, &[[[7998000]]] ); assert_eq!( - tensor.sum_keepdim(&[0])?.to_vec3::()?, + tensor.sum_keepdim(0)?.to_vec3::()?, &[[ [398000, 398200, 398400, 398600], [398800, 399000, 399200, 399400], diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index d7df5ae3..bf419072 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -604,7 +604,7 @@ fn main() -> Result<()> { println!("generated embeddings {:?}", embeddings.shape()); // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?; - let embeddings = (embeddings.sum_keepdim(&[1])? / (n_tokens as f64))?.squeeze(1)?; + let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; println!("pooled embeddings {:?}", embeddings.shape()); let mut similarities = vec![]; for i in 0..n_sentences { diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index 57f339b0..ce2e6d2e 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -95,7 +95,7 @@ impl RmsNorm { // This is a no-op if x's dtype is already f32. let x = x.to_dtype(DType::F32)?; let (b_sz, seq_len, hidden_size) = x.shape().r3()?; - let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; let size = self.scale.shape().r1()?; diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs index 652c47a7..31b1a162 100644 --- a/candle-examples/examples/musicgen/nn.rs +++ b/candle-examples/examples/musicgen/nn.rs @@ -70,7 +70,7 @@ pub fn conv1d_weight_norm( ) -> Result { let weight_g = vb.get((out_c, 1, 1), "weight_g")?; let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; - let norm_v = weight_v.sqr()?.sum_keepdim(&[1, 2])?.sqrt()?; + let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; let bias = vb.get(out_c, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs index 2119cf9b..15945d4e 100644 --- a/candle-examples/examples/musicgen/t5_model.rs +++ b/candle-examples/examples/musicgen/t5_model.rs @@ -98,7 +98,7 @@ impl T5LayerNorm { let dtype = xs.dtype(); let xs_f32 = xs.to_dtype(DType::F32)?; let xs2_f32 = (&xs_f32 * &xs_f32)?; - let sum_xs2_f32 = xs2_f32.sum_keepdim(&[xs.rank() - 1])?; + let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?; let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?; let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?; let xs = xs.to_dtype(dtype)?; diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 06f984f2..88d5ab32 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -51,9 +51,9 @@ impl LayerNorm { }; let (_bsize, _seq_len, hidden_size) = x.shape().r3()?; let x = x.to_dtype(internal_dtype)?; - let mean_x = (x.sum_keepdim(&[2])? / hidden_size as f64)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; let x = x.broadcast_sub(&mean_x)?; - let norm_x = (x.sqr()?.sum_keepdim(&[2])? / hidden_size as f64)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed .to_dtype(x_dtype)? diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs index fda7dbad..3a300cec 100644 --- a/candle-nn/tests/layer_norm.rs +++ b/candle-nn/tests/layer_norm.rs @@ -30,10 +30,10 @@ fn layer_norm() -> Result<()> { [4.1742344, 0.5, -3.1742344] ]] ); - let mean = (res.sum_keepdim(&[2])? / 3.0)?; + let mean = (res.sum_keepdim(2)? / 3.0)?; // The average value should be `b`. assert_eq!(mean.to_vec3::()?, [[[0.5], [0.5], [0.5]]]); - let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(&[2])?.sqrt()? / 3.0)?; + let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?; // The standard deviation should be sqrt(`w`). assert_eq!( std.to_vec3::()?,