diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 2b873e6e..8815c08d 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -40,6 +40,16 @@ pub trait BackendStorage: Sized {
     ) -> Result<Self>;
 
     fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
+    fn scatter_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self>;
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
     fn index_add(
         &self,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 678dbabd..38898b7b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -39,6 +39,7 @@ impl Tensor {
         } else if let Some(op) = node.op() {
             match op {
                 Op::IndexAdd(t1, t2, t3, _)
+                | Op::ScatterAdd(t1, t2, t3, _)
                 | Op::CustomOp3(t1, t2, t3, _)
                 | Op::WhereCond(t1, t2, t3) => {
                     let (tg, nodes) = walk(t1, nodes, already_seen);
@@ -56,6 +57,7 @@ impl Tensor {
                 }
                 | Op::CustomOp2(lhs, rhs, _)
                 | Op::Binary(lhs, rhs, _)
+                | Op::Gather(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
@@ -162,6 +164,11 @@ impl Tensor {
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::Gather(arg, indexes, dim) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+                    }
+                    Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
                     Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                     Op::IndexSelect(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
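The backward rule above is all gather needs: every output element is read from exactly one input slot, so the incoming gradient is routed back by scatter-adding it at the same indexes. A minimal sketch of that round trip, assuming this era of the candle API (Var::new, Tensor::backward, and GradStore::get are assumed to behave as in the surrounding code):

    use candle::{Device, Tensor, Var};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        let x = Var::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
        let ids = Tensor::new(&[[0u32, 2], [1, 1]], &dev)?;
        // y[i][j] = x[i][ids[i][j]] along dim 1, so y takes the shape of `ids`.
        let y = x.gather(&ids, 1)?;
        let grads = y.sum_all()?.backward()?;
        // Each gradient flows back to the slot it was gathered from:
        // row 0 gets +1 at columns 0 and 2, row 1 gets +2 at column 1.
        let gx = grads.get(&x).unwrap();
        println!("{:?}", gx.to_vec2::<f32>()?);
        Ok(())
    }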
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9e2d8699..b8d52c95 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -628,6 +628,59 @@ impl Map1 for Affine {
     }
 }
 
+struct Gather<'a> {
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+    dim: usize,
+}
+
+impl<'a> Map1 for Gather<'a> {
+    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        let ids = match self.ids_l.contiguous_offsets() {
+            Some((a, b)) => &self.ids[a..b],
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((a, b)) => &src[a..b],
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
+        };
+        let dim = self.dim;
+        let ids_dims = self.ids_l.dims();
+        let src_dims = src_l.dims();
+        let dst_len: usize = ids_dims.iter().product();
+        let dst_left_len: usize = ids_dims[..dim].iter().product();
+        let dst_dim_len = ids_dims[dim];
+        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+        let src_dim_len = src_dims[dim];
+        let src_right_len: usize = src_dims[dim + 1..].iter().product();
+
+        let mut dst = vec![T::zero(); dst_len];
+        for left_i in 0..dst_left_len {
+            let start_src_idx = left_i * src_right_len * src_dim_len;
+            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+            for i in 0..dst_dim_len {
+                let start_dst_idx = start_dst_idx + i * dst_right_len;
+                for right_i in 0..dst_right_len {
+                    let dst_idx = start_dst_idx + right_i;
+                    let index = ids[dst_idx] as usize;
+                    if index >= src_dim_len {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: src_dim_len,
+                            op: "gather",
+                        }
+                        .bt())?
+                    }
+                    let src_idx = start_src_idx + index * src_right_len + right_i;
+                    dst[dst_idx] = src[src_idx]
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct IndexSelect<'a> {
     ids: &'a [u32],
     ids_l: &'a Layout,
@@ -680,6 +733,63 @@ impl<'a> Map1 for IndexSelect<'a> {
     }
 }
 
+struct ScatterAdd<'a> {
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+    dim: usize,
+}
+
+impl<'a> Map2 for ScatterAdd<'a> {
+    const OP: &'static str = "scatter-add";
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        let dst_len = l1.shape().elem_count();
+        let mut dst = vec![T::zero(); dst_len];
+        copy_strided_src_(v1, &mut dst, 0, l1);
+        let src = match src_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+            Some((o1, o2)) => &src[o1..o2],
+        };
+
+        let dim = self.dim;
+        let ids_dims = self.ids_l.dims();
+        let dst_dims = l1.dims();
+        let dst_dim_len = dst_dims[dim];
+        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
+
+        let ids_left_len: usize = ids_dims[..dim].iter().product();
+        let ids_dim_len = ids_dims[dim];
+        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+        let ids = match self.ids_l.contiguous_offsets() {
+            Some((a, b)) => &self.ids[a..b],
+            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+        };
+        for left_i in 0..ids_left_len {
+            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
+            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+            for i in 0..ids_dim_len {
+                let start_ids_idx = start_ids_idx + i * ids_right_len;
+                for right_i in 0..dst_right_len {
+                    let ids_idx = start_ids_idx + right_i;
+                    let index = ids[ids_idx] as usize;
+                    if index >= dst_dim_len {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: dst_dim_len,
+                            op: "scatter-add",
+                        }
+                        .bt())?
+                    }
+                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
+                    dst[dst_idx] += src[ids_idx]
+                }
+            }
+        }
+
+        Ok(dst)
+    }
+}
+
 struct IndexAdd<'a> {
     ids: &'a [u32],
     dim: usize,
@@ -1593,6 +1703,24 @@ impl BackendStorage for CpuStorage {
         IndexSelect { ids, ids_l, dim }.map(self, l)
     }
 
+    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        Gather { ids, ids_l, dim }.map(self, l)
+    }
+
+    fn scatter_add(
+        &self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+    }
+
     fn index_add(
         &self,
         l: &Layout,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index a5633836..43bfef2d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1064,6 +1064,20 @@ impl BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
        Err(CudaError::InternalError("TODO: implement index-select").into())
     }
+    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement gather").into())
+    }
+    fn scatter_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement scatter-add").into())
+    }
     fn index_add(
         &self,
         _: &Layout,
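Both CPU kernels above use the same index arithmetic: a contiguous tensor is viewed as (left, dim, right) blocks, so element (l, d, r) lives at flat offset (l * dim_len + d) * right_len + r. A standalone sketch of the gather side, with a hypothetical helper name (gather_dim is not part of the patch):

    // Gathers along `dim` over row-major contiguous buffers, mirroring the
    // Gather kernel's loop structure without the error handling.
    fn gather_dim(src: &[f32], src_dims: &[usize], ids: &[u32], ids_dims: &[usize], dim: usize) -> Vec<f32> {
        let left: usize = ids_dims[..dim].iter().product();
        let ids_d = ids_dims[dim];
        let right: usize = ids_dims[dim + 1..].iter().product();
        let src_d = src_dims[dim];
        let mut dst = vec![0f32; left * ids_d * right];
        for l in 0..left {
            for d in 0..ids_d {
                for r in 0..right {
                    let dst_idx = (l * ids_d + d) * right + r;
                    let index = ids[dst_idx] as usize; // coordinate selected on `dim`
                    dst[dst_idx] = src[(l * src_d + index) * right + r];
                }
            }
        }
        dst
    }

    fn main() {
        // src is a 2x3 matrix in row-major order; gather on dim 1 with 2x2 ids.
        let src = [1f32, 2., 3., 4., 5., 6.];
        let out = gather_dim(&src, &[2, 3], &[0u32, 2, 1, 1], &[2, 2], 1);
        assert_eq!(out, vec![1., 3., 5., 5.]);
    }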
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 633f146e..c195cade 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -85,6 +85,22 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    fn scatter_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn index_add(
         &self,
         _: &Layout,
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index d36aa301..de5094bd 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -66,6 +66,8 @@ pub(crate) enum Op {
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
+    Gather(Tensor, Tensor, usize),
+    ScatterAdd(Tensor, Tensor, Tensor, usize),
     IndexSelect(Tensor, Tensor, usize),
     IndexAdd(Tensor, Tensor, Tensor, usize),
     WhereCond(Tensor, Tensor, Tensor),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 62e2d5e7..5e6cfdf2 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -325,6 +325,51 @@ impl Storage {
         }
     }
 
+    pub(crate) fn gather(
+        &self,
+        l: &Layout,
+        indexes: &Self,
+        indexes_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(indexes, "gather")?;
+        match (self, indexes) {
+            (Self::Cpu(s), Self::Cpu(indexes)) => {
+                let storage = s.gather(l, indexes, indexes_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(s), Self::Cuda(indexes)) => {
+                let storage = s.gather(l, indexes, indexes_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    pub(crate) fn scatter_add(
+        &self,
+        l: &Layout,
+        indexes: &Self,
+        indexes_l: &Layout,
+        source: &Self,
+        source_l: &Layout,
+        d: usize,
+    ) -> Result<Self> {
+        self.same_device(indexes, "scatter-add")?;
+        self.same_device(source, "scatter-add")?;
+        match (self, indexes, source) {
+            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
+                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
+                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+                Ok(Self::Cuda(storage))
+            }
+            _ => unreachable!(),
+        }
+    }
+
     pub(crate) fn index_add(
         &self,
         l: &Layout,
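At the tensor level the new op behaves like PyTorch's scatter_add_: the result starts as a copy of self, then for every position p in indexes, source[p] is accumulated into the slot obtained by replacing p's coordinate on dim with indexes[p]. A small sketch, assuming the tensor.rs API added below:

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        let init = Tensor::zeros((2, 3), DType::F32, &dev)?;
        let src = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
        let ids = Tensor::new(&[[0u32, 0, 2], [1, 1, 1]], &dev)?;
        let out = init.scatter_add(&ids, &src, 1)?;
        // Row 0: columns 0, 0, 2 receive 1, 2, 3 -> [3, 0, 3]
        // Row 1: column 1 receives 4 + 5 + 6    -> [0, 15, 0]
        println!("{:?}", out.to_vec2::<f32>()?);
        Ok(())
    }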
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1d6e4e3f..8ba0ba43 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -945,6 +945,57 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
+    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "scatter-add")?;
+        let source_dims = source.dims();
+        let self_dims = self.dims();
+        let mismatch = if source_dims.len() != self_dims.len() {
+            true
+        } else {
+            let mut mismatch = false;
+            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+                if i != dim && d1 != d2 {
+                    mismatch = true;
+                    break;
+                }
+            }
+            mismatch
+        };
+        if mismatch {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "scatter-add (self, src)",
+                lhs: self.shape().clone(),
+                rhs: source.shape().clone(),
+            })?
+        }
+        if indexes.dims() != source.dims() {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "scatter-add (indexes, src)",
+                lhs: indexes.shape().clone(),
+                rhs: source.shape().clone(),
+            })?
+        }
+        let storage = self.storage().scatter_add(
+            self.layout(),
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::ScatterAdd(
+                self.clone(),
+                indexes.clone(),
+                source.clone(),
+                dim,
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
         let source_dims = source.dims();
@@ -992,6 +1043,40 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "gather")?;
+        let self_dims = self.dims();
+        let indexes_dims = indexes.dims();
+        let mismatch = if indexes_dims.len() != self_dims.len() {
+            true
+        } else {
+            let mut mismatch = false;
+            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
+                if i != dim && d1 != d2 {
+                    mismatch = true;
+                    break;
+                }
+            }
+            mismatch
+        };
+        if mismatch {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "gather",
+                lhs: self.shape().clone(),
+                rhs: indexes.shape().clone(),
+            })?
+        }
+        let storage =
+            self.storage()
+                .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
+        let op = if indexes.track_op() || self.track_op() {
+            Some(Op::Gather(self.clone(), indexes.clone(), dim))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, indexes.shape(), op, false))
+    }
+
     pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-select")?;
         let indexes_len = match indexes.dims() {
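Gather differs from index_select: the indexes tensor has the same rank as self (matching on every dimension except dim) and selects one entry per output position, so the output takes the shape of indexes. A sketch of the per-row lookup the training example below relies on, assuming this commit's API:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        // Two rows of three class log-probabilities.
        let log_sm = Tensor::new(&[[-0.1f32, -1.2, -2.3], [-0.4, -0.5, -0.6]], &dev)?;
        // One target class per row, shaped (batch, 1) to match log_sm's rank.
        let target = Tensor::new(&[[2u32], [0]], &dev)?;
        let picked = log_sm.gather(&target, 1)?; // shape (2, 1): [[-2.3], [-0.4]]
        println!("{:?}", picked.to_vec2::<f32>()?);
        Ok(())
    }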
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index ea2dc0cd..2cfe4923 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -17,10 +17,11 @@ fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> candle::Result<Tensor> {
 }
 
-fn nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
-    let b_sz = target.shape().r1()?;
-    inp.index_select(target, 0)?.sum_all()? / b_sz as f64
+fn nll_loss(inp: &Tensor, target: &Tensor) -> candle::Result<Tensor> {
+    let b_sz = target.dim(0)?;
+    inp.gather(target, 1)?
+        .sum_all()?
+        .affine(-1f64 / b_sz as f64, 0.)
 }
 
 pub fn main() -> Result<()> {
@@ -32,12 +33,7 @@ pub fn main() -> Result<()> {
     println!("test-labels: {:?}", m.test_labels.shape());
     let train_labels = m.train_labels;
     let train_images = m.train_images;
-    let train_labels = train_labels.to_vec1::<u8>()?;
-    let train_label_mask = train_labels
-        .iter()
-        .flat_map(|l| (0..LABELS).map(|i| f32::from(i == *l as usize)))
-        .collect::<Vec<f32>>();
-    let train_label_mask = Tensor::from_vec(train_label_mask, (train_labels.len(), LABELS), &dev)?;
+    let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;
     let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
     let bs = Var::zeros(LABELS, DType::F32, &dev)?;
     let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
@@ -46,9 +42,7 @@ pub fn main() -> Result<()> {
     for epoch in 1..200 {
         let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
         let log_sm = log_softmax(&logits, D::Minus1)?;
-        let loss = (&log_sm * &train_label_mask)?
-            .sum_all()?
-            .affine(-1f64 / train_images.dim(0)? as f64, 0f64)?;
+        let loss = nll_loss(&log_sm, &train_labels)?;
         sgd.backward_step(&loss)?;
 
         let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
@@ -63,7 +57,7 @@ pub fn main() -> Result<()> {
             "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
             loss.to_scalar::<f32>()?,
             100. * test_accuracy
-        )
+        );
     }
     Ok(())
 }
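The new nll_loss is the deleted one-hot mask trick collapsed into a single op: gathering the log-probability at each target class and scaling with affine(-1/b, 0) gives the same value as summing log_sm * one_hot(target). A quick equivalence check, assuming this commit's API (the one-hot mask is written out by hand since no helper exists here):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::Cpu;
        let log_sm = Tensor::new(&[[-0.1f32, -1.2, -2.3], [-0.4, -0.5, -0.6]], &dev)?;
        let target = Tensor::new(&[[2u32], [0]], &dev)?;

        let b_sz = target.dim(0)?;
        // Gather-based loss, as in nll_loss above.
        let gathered = log_sm
            .gather(&target, 1)?
            .sum_all()?
            .affine(-1f64 / b_sz as f64, 0.)?;

        // Hand-rolled one-hot mask, matching the code this patch deletes.
        let mask = Tensor::new(&[[0f32, 0., 1.], [1., 0., 0.]], &dev)?;
        let masked = (&log_sm * &mask)?.sum_all()?.affine(-1f64 / b_sz as f64, 0.)?;

        // Both print 1.35 for this batch.
        println!("{} == {}", gathered.to_scalar::<f32>()?, masked.to_scalar::<f32>()?);
        Ok(())
    }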