diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 7b493d31..678dbabd 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -164,26 +164,8 @@ impl Tensor {
                 Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                 Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                 Op::IndexSelect(arg, indexes, dim) => {
-                    let dim = *dim;
                     let sum_grad = grads.or_insert(arg)?;
-                    // TODO: This is very very very inefficient, have some dedicated kernel for this.
-                    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
-                    let indexes = indexes.to_vec1::<u32>()?;
-                    for (dst_index, src_index) in indexes.iter().enumerate() {
-                        let src_index = *src_index as usize;
-                        let dst_grad_for_index = grad.narrow(dim, dst_index, 1)?;
-                        let mut pre_dims = arg.dims().to_vec();
-                        pre_dims[dim] = src_index;
-                        let pre_zeros =
-                            Tensor::zeros(pre_dims, sum_grad.dtype(), sum_grad.device())?;
-                        let mut post_dims = arg.dims().to_vec();
-                        post_dims[dim] = post_dims[dim] - src_index - 1;
-                        let post_zeros =
-                            Tensor::zeros(post_dims, sum_grad.dtype(), sum_grad.device())?;
-                        let src_grad =
-                            Tensor::cat(&[pre_zeros, dst_grad_for_index, post_zeros], dim)?;
-                        *sum_grad = sum_grad.add(&src_grad)?;
-                    }
+                    *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                 }
                 Op::Embedding(_lhs, _rhs) => {
                     Err(Error::BackwardNotSupported { op: "embedding" })?
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8e9b1d8e..9e2d8699 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -665,7 +665,7 @@ impl<'a> Map1 for IndexSelect<'a> {
                     if index >= src_dim {
                         Err(Error::InvalidIndex {
                             index,
-                            src_size: src_dim,
+                            size: src_dim,
                             op: "index-select",
                         }
                         .bt())?
@@ -680,6 +680,72 @@
     }
 }
 
+struct IndexAdd<'a> {
+    ids: &'a [u32],
+    dim: usize,
+}
+
+impl<'a> Map2 for IndexAdd<'a> {
+    const OP: &'static str = "index-add";
+    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
+    // v1, l1 -> self
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        let dst_len = l1.shape().elem_count();
+        let mut dst = vec![T::zero(); dst_len];
+        copy_strided_src_(v1, &mut dst, 0, l1);
+        let src = match src_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+            Some((o1, o2)) => &src[o1..o2],
+        };
+        let dim = self.dim;
+        let max_idx = l1.dims()[dim];
+        let stride = src_l.stride()[dim];
+        if dim == 0 {
+            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx as usize;
+                if dst_idx >= max_idx {
+                    Err(Error::InvalidIndex {
+                        index: dst_idx,
+                        op: "index-add",
+                        size: max_idx,
+                    })?
+                }
+                let src_idx = src_idx * stride;
+                let dst_idx = dst_idx * stride;
+                let src = &src[src_idx..src_idx + stride];
+                let dst = &mut dst[dst_idx..dst_idx + stride];
+                for (d, &s) in dst.iter_mut().zip(src.iter()) {
+                    *d += s
+                }
+            }
+        } else {
+            let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
+            let src_dim_sz = src_l.dims()[dim];
+            let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
+            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx as usize;
+                if dst_idx >= max_idx {
+                    Err(Error::InvalidIndex {
+                        index: dst_idx,
+                        op: "index-add",
+                        size: max_idx,
+                    })?
+                }
+                for pre_i in 0..pre_dim {
+                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
+                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
+                    let src = &src[pre_src_i..pre_src_i + post_dim];
+                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
+                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
+                        *d += s
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Embedding<'a> {
     vocab_size: usize,
     hidden_size: usize,
@@ -698,7 +764,7 @@ impl<'a> Map1 for Embedding<'a> {
             if index >= self.vocab_size {
                 Err(Error::InvalidIndex {
                     index,
-                    src_size: self.vocab_size,
+                    size: self.vocab_size,
                     op: "take",
                 }
                 .bt())?
@@ -711,12 +777,7 @@
     }
 }
 
-fn copy_strided_src_<T: Copy>(
-    src: &[T],
-    dst: &mut [T],
-    dst_offset: usize,
-    src_l: &Layout,
-) {
+fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
             let to_copy = (dst.len() - dst_offset).min(len);
@@ -1534,14 +1595,19 @@ impl BackendStorage for CpuStorage {
 
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        todo!()
+        let ids = ids.as_slice::<u32>()?;
+        let ids = match ids_l.contiguous_offsets() {
+            Some((a, b)) => &ids[a..b],
+            None => Err(Error::RequiresContiguous { op: "index-add" })?,
+        };
+        IndexAdd { ids, dim }.map(self, l, src, src_l)
     }
 
     fn matmul(
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 23f2642d..daf24e6a 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -112,11 +112,11 @@ pub enum Error {
         msg: &'static str,
     },
 
-    #[error("{op} invalid index {index} with src dim size {src_size}")]
+    #[error("{op} invalid index {index} with dim size {size}")]
     InvalidIndex {
         op: &'static str,
         index: usize,
-        src_size: usize,
+        size: usize,
     },
 
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d4ee34f9..1d6e4e3f 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -947,6 +947,30 @@ impl Tensor {
 
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
+        let source_dims = source.dims();
+        let self_dims = self.dims();
+        let mismatch = if source_dims.len() != self_dims.len() {
+            true
+        } else {
+            let mut mismatch = false;
+            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+                if i != dim && d1 != d2 {
+                    mismatch = true;
+                    break;
+                }
+            }
+            // The number of elements in `indexes` must match the size of `source` along the
+            // dimension on which the add is performed (the index values in `indexes` refer to
+            // positions along that dimension of the target tensor, `self`).
+            mismatch || source_dims[dim] != indexes.shape().r1()?
+        };
+        if mismatch {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "index-add",
+                lhs: self.shape().clone(),
+                rhs: source.shape().clone(),
+            })?
+        }
         let storage = self.storage().index_add(
             self.layout(),
             &indexes.storage(),
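Usage sketch (not part of the diff): a minimal example of the new `index_add` entry point on the CPU backend, assuming the crate is imported as `candle_core` and built from this revision. Per the shape check added in tensor.rs, `indexes` must be a rank-1 u32 tensor whose length matches the size of `source` along `dim`, and its values select positions in `self`.

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // Accumulate the rows of `source` into the rows of `init` picked by `ids`,
    // following the `torch.Tensor.index_add_` semantics referenced in the kernel.
    let init = Tensor::zeros((4, 2), DType::F32, &dev)?;
    let source = Tensor::new(&[[1f32, 1.], [2., 2.], [3., 3.]], &dev)?;
    let ids = Tensor::new(&[0u32, 2, 0], &dev)?;
    let out = init.index_add(&ids, &source, 0)?;
    // Source rows 0 and 2 both target row 0, so their sum lands there.
    assert_eq!(out.to_vec2::<f32>()?, [[4f32, 4.], [0., 0.], [2., 2.], [0., 0.]]);
    Ok(())
}

This is also the pattern the backprop change relies on: the gradient of `index_select` is now scattered back into `sum_grad` with a single `index_add` call instead of the old narrow/cat loop.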
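For the `dim > 0` path, the kernel depends on the flat-offset identity for contiguous layouts: element (p, i, j) of a [pre, n, post]-shaped tensor lives at offset (p * n + i) * post + j, so each indexed slice along `dim` is a contiguous chunk of `post` elements. A dependency-free sketch of that arithmetic (function and variable names hypothetical, not from the diff):

use std::ops::Range;

// Flat range covering slice (pre_i, idx, ..) of a contiguous [pre, dim_sz, post] tensor.
fn slice_range(pre_i: usize, dim_sz: usize, idx: usize, post_dim: usize) -> Range<usize> {
    let start = (pre_i * dim_sz + idx) * post_dim;
    start..start + post_dim
}

fn main() {
    // dst: shape [2, 4, 3] (dim = 1, so max_idx = 4); src: shape [2, 2, 3].
    let mut dst = vec![0f32; 2 * 4 * 3];
    let src: Vec<f32> = (0..2 * 2 * 3).map(|v| v as f32).collect();
    let ids = [1usize, 3]; // src slice i is accumulated into dst slice ids[i]
    for (src_idx, &dst_idx) in ids.iter().enumerate() {
        for pre_i in 0..2 {
            let s = &src[slice_range(pre_i, 2, src_idx, 3)];
            let d = &mut dst[slice_range(pre_i, 4, dst_idx, 3)];
            for (d, &s) in d.iter_mut().zip(s.iter()) {
                *d += s;
            }
        }
    }
    // src slice (0, 0, ..) = [0., 1., 2.] landed at dst slice (0, 1, ..).
    assert_eq!(&dst[3..6], &[0., 1., 2.]);
}

Note the separate multipliers: `src` strides by its own size along `dim` while `dst` strides by `max_idx`, which is why the kernel computes `pre_src_i` and `pre_dst_i` independently.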