From 38276855249c28a3eaaf116aaaaac0cb2387efa8 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 25 Apr 2025 21:46:58 +0200
Subject: [PATCH] Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
---
 candle-core/src/backend.rs              |  9 ++++
 candle-core/src/backprop.rs             | 13 ++++-
 candle-core/src/cpu_backend/mod.rs      | 64 +++++++++++++++++++++----
 candle-core/src/cuda_backend/mod.rs     | 63 ++++++++++++++++++++++++
 candle-core/src/dummy_cuda_backend.rs   | 12 +++++
 candle-core/src/dummy_metal_backend.rs  | 12 +++++
 candle-core/src/metal_backend/mod.rs    | 52 +++++++++++++++++++-
 candle-core/src/op.rs                   |  1 +
 candle-core/src/storage.rs              | 28 +++++++++++
 candle-core/src/tensor.rs               | 46 ++++++++++++++++++
 candle-core/tests/tensor_tests.rs       | 32 ++++++++++---
 candle-kernels/src/indexing.cu          | 59 +++++++++++++++++++++++
 candle-metal-kernels/src/indexing.metal | 53 ++++++++++++++++++++
 candle-metal-kernels/src/lib.rs         |  2 +-
 candle-metal-kernels/src/tests.rs       |  2 +-
 15 files changed, 429 insertions(+), 19 deletions(-)

diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 8ab59f4a..f3655065 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -71,6 +71,15 @@ pub trait BackendStorage: Sized {
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
 
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
+    fn scatter(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self>;
     fn scatter_add(
         &self,
         _: &Layout,
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d8f1b786..a9577013 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -53,6 +53,7 @@ impl Tensor {
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
+                    | Op::Scatter(t1, t2, t3, _)
                     | Op::ScatterAdd(t1, t2, t3, _)
                     | Op::CustomOp3(t1, t2, t3, _)
                     | Op::WhereCond(t1, t2, t3) => {
@@ -419,14 +420,24 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                 }
-                Op::ScatterAdd(init, indexes, src, dim) => {
+                Op::Scatter(init, indexes, src, dim) => {
+                    let init_sum_grad = grads.or_insert(init)?;
+                    let mask = init.ones_like()?;
+                    let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
+                    *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
+
+                    let src_grad = grad.gather(indexes, *dim)?;
+                    let src_sum_grad = grads.or_insert(src)?;
+                    *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                }
+                Op::ScatterAdd(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
 
                     let src_grad = grad.gather(indexes, *dim)?;
                     let src_sum_grad = grads.or_insert(src)?;
                     *src_sum_grad = src_sum_grad.add(&src_grad)?;
                 }
                 Op::IndexAdd(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
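
The masking trick in the `Op::Scatter` arm above deserves a second look: since the forward `scatter` overwrites `init` at the indexed positions, those slots must receive no gradient, whereas `scatter_add` passes the incoming gradient through unchanged. Scattering zeros into a ones tensor marks exactly the overwritten slots. A minimal standalone sketch of just that mask construction (assuming only the published `candle_core` API):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Same construction as the Op::Scatter arm: start from ones and scatter
    // zeros at the indexed positions. Multiplying the incoming gradient by
    // this mask zeroes it wherever the forward scatter overwrote `init`.
    let ones = Tensor::ones((1, 4), DType::F32, &dev)?;
    let ids = Tensor::new(&[[1u32, 3, 1, 3]], &dev)?;
    let mask = ones.scatter(&ids, &ones.zeros_like()?, 1)?;
    assert_eq!(mask.to_vec2::<f32>()?, [[1.0, 0.0, 1.0, 0.0]]);
    Ok(())
}
```

With the mask in hand, the `src` side of the rule is the usual `gather` of the incoming gradient, identical to the `scatter_add` case.
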
diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs
index a405320c..c9edeb5b 100644
--- a/candle-core/src/cpu_backend/mod.rs
+++ b/candle-core/src/cpu_backend/mod.rs
@@ -554,20 +554,51 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
     }
 }
 
-struct ScatterAdd<'a, I: IntDType> {
+trait ElemUpdate {
+    fn f<T: WithDType>(dst: &mut T, src: T);
+}
+
+struct Set;
+struct Add;
+
+impl ElemUpdate for Set {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst = src
+    }
+}
+
+impl ElemUpdate for Add {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst += src
+    }
+}
+
+struct Scatter<'a, I: IntDType, M: ElemUpdate> {
     ids: &'a [I],
     ids_l: &'a Layout,
     dim: usize,
+    _phantom: std::marker::PhantomData<M>,
 }
 
-impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
-    const OP: &'static str = "scatter-add";
+impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
+    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
+        Self {
+            ids,
+            ids_l,
+            dim,
+            _phantom: Default::default(),
+        }
+    }
+}
+
+impl<I: IntDType, M: ElemUpdate> Map2 for Scatter<'_, I, M> {
+    const OP: &'static str = "scatter";
     fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let dst_len = l1.shape().elem_count();
         let mut dst = vec![T::zero(); dst_len];
         copy_strided_src_(v1, &mut dst, 0, l1);
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };
 
@@ -602,7 +633,7 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
                         .bt())?
                     }
                     let dst_idx = start_dst_idx + index * dst_right_len + right_i;
-                    dst[dst_idx] += src[ids_idx]
+                    M::f(&mut dst[dst_idx], src[ids_idx])
                 }
             }
         }
@@ -2381,6 +2412,23 @@ impl BackendStorage for CpuStorage {
         }
     }
 
+    fn scatter(
+        &self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<Self> {
+        match ids {
+            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
+        }
+    }
+
     fn scatter_add(
         &self,
         l: &Layout,
@@ -2391,9 +2439,9 @@ impl BackendStorage for CpuStorage {
         dim: usize,
     ) -> Result<Self> {
         match ids {
-            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
         }
     }
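
The `ElemUpdate` refactor folds the old `ScatterAdd` kernel and the new scatter kernel into one generic implementation: the traversal is identical and only the element update differs, with `Set` overwriting and `Add` accumulating, chosen at compile time through the `M` parameter. A standalone sketch of the same dispatch, reduced to flat slices (`scatter_1d` is a hypothetical simplification, not code from the patch):

```rust
// Standalone sketch of the ElemUpdate dispatch used in the CPU backend;
// Set/Add mirror the patch, the kernel here is simplified to 1-D.
trait ElemUpdate {
    fn f(dst: &mut f32, src: f32);
}

struct Set;
struct Add;

impl ElemUpdate for Set {
    fn f(dst: &mut f32, src: f32) {
        *dst = src // scatter: last write wins
    }
}

impl ElemUpdate for Add {
    fn f(dst: &mut f32, src: f32) {
        *dst += src // scatter_add: duplicates accumulate
    }
}

fn scatter_1d<M: ElemUpdate>(dst: &mut [f32], ids: &[usize], src: &[f32]) {
    for (&i, &s) in ids.iter().zip(src) {
        M::f(&mut dst[i], s);
    }
}

fn main() {
    let mut a = [1.0f32; 3];
    let mut b = [1.0f32; 3];
    scatter_1d::<Set>(&mut a, &[1, 1], &[2.0, 3.0]);
    scatter_1d::<Add>(&mut b, &[1, 1], &[2.0, 3.0]);
    assert_eq!(a, [1.0, 3.0, 1.0]); // overwrite: 3.0 replaced 2.0
    assert_eq!(b, [1.0, 6.0, 1.0]); // accumulate: 1 + 2 + 3
}
```

Because `M::f` is resolved by monomorphization, the dispatch costs nothing at run time, and the duplicate-index behavior falls out directly: last write wins for `Set`, values accumulate for `Add`.
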
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 00765af9..c36339b0 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -552,6 +552,54 @@ impl Map2InPlace for IndexAdd<'_> {
     }
 }
 
+struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
+impl Map2InPlace for Scatter<'_> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_shape: &Shape,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "scatter ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let dst_dim_sz = dst_shape.dims()[dim];
+        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
+        let mut builder = func.builder();
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(dst);
+        barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
+        // SAFETY: ffi.
+        unsafe { builder.launch(cfg) }.w()?;
+        Ok(())
+    }
+}
+
 struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
 impl Map2InPlace for ScatterAdd<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1838,6 +1886,21 @@ impl BackendStorage for CudaStorage {
         let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
         Ok(Self { slice, device })
     }
+    fn scatter(
+        &self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<Self> {
+        let device = self.device().clone();
+        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
+        self.copy_strided_src(&mut acc, 0, l)?;
+        Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+        Ok(acc)
+    }
     fn scatter_add(
         &self,
         l: &Layout,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 358081a0..0d635d75 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -128,6 +128,18 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn scatter(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn scatter_add(
         &self,
         _: &Layout,
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index 434e8d7b..80493024 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -132,6 +132,18 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
    }
 
+    fn scatter(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn scatter_add(
         &self,
         _: &Layout,
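
Note how the CUDA kernel is launched over `left_sz * right_sz` elements: one thread per `(left, right)` coordinate pair, with the scattered dimension walked sequentially inside each thread. Two threads therefore never touch the same destination row, so neither scatter nor scatter_add needs atomics. A host-side Rust model of that index arithmetic (`scatter_model` is a hypothetical helper mirroring the kernel's `pre`/`post` computation):

```rust
// Host-side model of the CUDA kernel's index math: one "thread" per
// (left, right) pair, iterating over the scatter dim sequentially.
fn scatter_model(
    ids: &[usize],   // shape: left × src_dim × right, flattened
    src: &[f32],     // same shape as ids
    dst: &mut [f32], // shape: left × dst_dim × right, flattened
    (left, src_dim, dst_dim, right): (usize, usize, usize, usize),
) {
    for i in 0..left * right {
        let (pre, post) = (i / right, i % right);
        for j in 0..src_dim {
            let src_i = (pre * src_dim + j) * right + post;
            let dst_i = (pre * dst_dim + ids[src_i]) * right + post;
            dst[dst_i] = src[src_i]; // the `Add` variant would use `+=`
        }
    }
}

fn main() {
    let mut dst = vec![0.0f32; 6]; // 1 × 3 × 2 destination
    // src/ids are 1 × 2 × 2: row 0 of src goes to dst row 2, row 1 to dst row 0.
    scatter_model(&[2, 2, 0, 0], &[1., 2., 3., 4.], &mut dst, (1, 2, 3, 2));
    assert_eq!(dst, [3., 4., 0., 0., 1., 2.]);
}
```
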
u8/u32/i64", + expected: DType::U32, + got: ids.dtype(), + })?, + }; + let command_buffer = self.device.command_buffer()?; + let src = buffer_o(&src.buffer, src_l, src.dtype); + let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); + candle_metal_kernels::call_scatter( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + l.dims(), + dim, + src, + ids, + &acc.buffer, + ) + .map_err(MetalError::from)?; + Ok(acc) + } + fn scatter_add( &self, l: &Layout, @@ -1460,7 +1510,7 @@ impl BackendStorage for MetalStorage { let command_buffer = self.device.command_buffer()?; let src = buffer_o(&src.buffer, src_l, src.dtype); let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); - candle_metal_kernels::call_scatter_add( + candle_metal_kernels::call_scatter( &self.device.device, &command_buffer, &self.device.kernels, diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index c5fc3fc4..e2627f76 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -80,6 +80,7 @@ pub enum Op { Reduce(Tensor, ReduceOp, Vec), Matmul(Tensor, Tensor), Gather(Tensor, Tensor, usize), + Scatter(Tensor, Tensor, Tensor, usize), ScatterAdd(Tensor, Tensor, Tensor, usize), IndexSelect(Tensor, Tensor, usize), IndexAdd(Tensor, Tensor, Tensor, usize), diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 3148a00a..4257481b 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -628,6 +628,34 @@ impl Storage { } } + pub(crate) fn scatter( + &self, + l: &Layout, + indexes: &Self, + indexes_l: &Layout, + source: &Self, + source_l: &Layout, + d: usize, + ) -> Result { + self.same_device(indexes, "scatter-add")?; + self.same_device(source, "scatter-add")?; + match (self, indexes, source) { + (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => { + let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cpu(storage)) + } + (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => { + let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Cuda(storage)) + } + (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { + let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; + Ok(Self::Metal(storage)) + } + _ => unreachable!(), + } + } + pub(crate) fn scatter_add( &self, l: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index cd51ccbc..26e2e3b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1354,6 +1354,52 @@ impl Tensor { self.index_select(ids, 0) } + pub fn scatter(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "scatter")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter (self, src)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? + } + if indexes.dims() != source.dims() { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter (indexes, src)", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + } + .bt())? 
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index cd51ccbc..26e2e3b5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1354,6 +1354,52 @@ impl Tensor {
         self.index_select(ids, 0)
     }
 
+    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "scatter")?;
+        let source_dims = source.dims();
+        let self_dims = self.dims();
+        let mismatch = if source_dims.len() != self_dims.len() {
+            true
+        } else {
+            let mut mismatch = false;
+            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+                if i != dim && d1 != d2 {
+                    mismatch = true;
+                    break;
+                }
+            }
+            mismatch
+        };
+        if mismatch {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: source.shape().clone(),
+            }
+            .bt())?
+        }
+        if indexes.dims() != source.dims() {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "scatter (indexes, src)",
+                lhs: indexes.shape().clone(),
+                rhs: source.shape().clone(),
+            }
+            .bt())?
+        }
+        let storage = self.storage().scatter(
+            self.layout(),
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::Scatter(t1, t2, t3, dim)
+        });
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "scatter-add")?;
         let source_dims = source.dims();
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 7d33f9d7..7e2d41ba 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1027,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn scatter_add(device: &Device) -> Result<()> {
+fn scatter(device: &Device) -> Result<()> {
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
         t.to_vec2::<f32>()?,
@@ -1051,6 +1051,17 @@
         ]
     );
 
+    let hs = init.scatter(&ids, &t, 1)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0, 1.0, 1.0],
+            [5.0, 1.0, 1.0, 3.0, 4.0],
+            [1.0, 8.0, 1.0, 7.0, 1.0],
+            [10.0, 1.0, 9.0, 1.0, 11.0]
+        ]
+    );
+
     let init = Tensor::ones((6, 3), DType::F32, device)?;
     let hs = init.scatter_add(&ids, &t, 0)?;
     assert_eq!(
@@ -1064,6 +1075,18 @@
             [1.0, 1.0, 1.0]
         ]
     );
+    let hs = init.scatter(&ids, &t, 0)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[
+            [0.0, 10.0, 5.0],
+            [1.0, 1.0, 8.0],
+            [9.0, 1.0, 2.0],
+            [6.0, 7.0, 1.0],
+            [1.0, 4.0, 11.0],
+            [1.0, 1.0, 1.0]
+        ]
+    );
     Ok(())
 }
 
@@ -1563,12 +1586,7 @@ test_device!(
 );
 test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
 test_device!(gather, gather_cpu, gather_gpu, gather_metal);
-test_device!(
-    scatter_add,
-    scatter_add_cpu,
-    scatter_add_gpu,
-    scatter_add_metal
-);
+test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
 test_device!(
     slice_scatter,
     slice_scatter_cpu,
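
At the tensor level the new op is the write-side counterpart of `gather`: along `dim`, `hs[i][ids[i][j]] = src[i][j]`, while every position not named by `ids` keeps the value from `self`. A quick usage sketch matching the semantics the tests encode (assuming only the published `candle_core` API):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::ones((2, 4), DType::F32, &dev)?;
    let src = Tensor::new(&[[10f32, 20.], [30., 40.]], &dev)?;
    let ids = Tensor::new(&[[0u32, 3], [2, 2]], &dev)?;
    // dst[i][ids[i][j]] = src[i][j] along dim 1; untouched slots keep 1.0.
    let hs = init.scatter(&ids, &src, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        [[10.0, 1.0, 1.0, 20.0], [1.0, 1.0, 40.0, 1.0]]
    );
    Ok(())
}
```

Note the second row: `ids` repeats index 2, and later writes win, matching the sequential `j` loop in the kernels.
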
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 7074fa0b..f2327f27 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -114,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
     const size_t right_size \
 ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
 
+template <typename T, typename I>
+__device__ void scatter(
+    const I *ids,
+    const T *inp,
+    T *out,
+    const size_t left_size,
+    const size_t src_dim_size,
+    const size_t dst_dim_size,
+    const size_t right_size
+) {
+    const size_t numel = left_size * right_size;
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        const size_t pre = i / right_size;
+        const size_t post = i % right_size;
+        for (unsigned int j = 0; j < src_dim_size; ++j) {
+            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
+            const size_t idx = ids[src_i];
+            assert(idx < dst_dim_size);
+            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
+            out[dst_i] = inp[src_i];
+        }
+    }
+}
+
 template <typename T, typename I>
 __device__ void scatter_add(
     const I *ids,
@@ -138,6 +162,17 @@ __device__ void scatter_add(
     }
 }
 
+#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const INDEX_TYPENAME *ids, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t left_size, \
+    const size_t src_dim_size, \
+    const size_t dst_dim_size, \
+    const size_t right_size \
+) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
+
 #define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const INDEX_TYPENAME *ids, \
@@ -163,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
 SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
 SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
 SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
+S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
+S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
+S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -178,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
 SA_OP(__half, int64_t, sa_i64_f16)
 SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
+S_OP(__half, int64_t, s_i64_f16)
+S_OP(__half, uint32_t, s_u32_f16)
+S_OP(__half, uint8_t, s_u8_f16)
 #endif
 
 IS_OP(float, int64_t, is_i64_f32)
@@ -251,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
 SA_OP(uint8_t, uint8_t, sa_u8_u8)
 SA_OP(uint32_t, uint8_t, sa_u8_u32)
 SA_OP(int64_t, uint8_t, sa_u8_i64)
+
+S_OP(float, int64_t, s_i64_f32)
+S_OP(double, int64_t, s_i64_f64)
+S_OP(uint8_t, int64_t, s_i64_u8)
+S_OP(int64_t, int64_t, s_i64_i64)
+S_OP(uint32_t, int64_t, s_i64_u32)
+
+S_OP(float, uint32_t, s_u32_f32)
+S_OP(double, uint32_t, s_u32_f64)
+S_OP(uint8_t, uint32_t, s_u32_u8)
+S_OP(int64_t, uint32_t, s_u32_i64)
+S_OP(uint32_t, uint32_t, s_u32_u32)
+
+S_OP(float, uint8_t, s_u8_f32)
+S_OP(double, uint8_t, s_u8_f64)
+S_OP(uint8_t, uint8_t, s_u8_u8)
+S_OP(uint32_t, uint8_t, s_u8_u32)
+S_OP(int64_t, uint8_t, s_u8_i64)
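
Every `S_OP` instantiation pairs one index type with one element type, and the Rust side selects the kernel purely by name: the backend match picks the index-type prefix (`s_u8`/`s_u32`/`s_i64`) and `kernel_name::<T>` appends the element-type suffix, so each composed name must exist as an `S_OP` above (or a `SCATTER_OP` on Metal below). A hypothetical mirror of that composition:

```rust
// Hypothetical mirror of how the kernel name is assembled: the backend match
// supplies "s_u8" / "s_u32" / "s_i64" and the element dtype suffix comes
// from the generic parameter T (represented here as a string).
fn scatter_kernel_name(ids_prefix: &str, elem_suffix: &str) -> String {
    format!("{ids_prefix}_{elem_suffix}")
}

fn main() {
    // Each of these must exist as an S_OP / SCATTER_OP instantiation.
    assert_eq!(scatter_kernel_name("s_u32", "f32"), "s_u32_f32");
    assert_eq!(scatter_kernel_name("s_i64", "bf16"), "s_i64_bf16");
}
```
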
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index df374d20..d596a619 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -104,6 +104,31 @@ kernel void NAME( \
     gather(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
 }
 
+template <typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void scatter(
+    constant size_t &dst_size,
+    constant size_t &left_size,
+    constant size_t &src_dim_size,
+    constant size_t &right_size,
+    constant size_t &dst_dim_size,
+    const device TYPENAME *input,
+    const device INDEX_TYPENAME *input_ids,
+    device TYPENAME *output,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    if (tid >= dst_size) {
+        return;
+    }
+    const size_t right_rank_i = tid % right_size;
+    const size_t left_rank_i = tid / right_size;
+    for (unsigned int j = 0; j < src_dim_size; ++j) {
+        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+        const INDEX_TYPENAME idx = input_ids[src_i];
+        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+        output[dst_i] = input[src_i];
+    }
+}
+
 template <typename TYPENAME, typename INDEX_TYPENAME>
 METAL_FUNC void scatter_add(
     constant size_t &dst_size,
@@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
     }
 }
 
+# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &dst_dim_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    scatter(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
+}
+
 # define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
 kernel void NAME( \
     constant size_t &dst_size, \
@@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
 SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
 #endif
 
+SCATTER_OP(s_u32_f32, uint32_t, float)
+SCATTER_OP(s_u8_f32, uint8_t, float)
+SCATTER_OP(s_i64_f32, int64_t, float)
+SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
+SCATTER_OP(s_u32_f16, uint32_t, half)
+SCATTER_OP(s_u8_f16, uint8_t, half)
+SCATTER_OP(s_i64_f16, int64_t, half)
+#if defined(__HAVE_BFLOAT__)
+SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
+SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
+SCATTER_OP(s_i64_bf16, int64_t, bfloat)
+#endif
+
 // i64
 INDEX_ADD_OP(ia_i64_f16, int64_t, half)
 INDEX_ADD_OP(ia_i64_f32, int64_t, float)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index be31f824..9f689a07 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1447,7 +1447,7 @@ pub fn call_gather(
 }
 
 #[allow(clippy::too_many_arguments)]
-pub fn call_scatter_add(
+pub fn call_scatter(
     device: &Device,
     ep: impl EncoderProvider,
     kernels: &Kernels,
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 9121f671..ee130d6b 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1574,7 +1574,7 @@ fn run_scatter_add(
     let input_buffer = new_buffer(&device, input);
     let ids_buffer = new_buffer(&device, ids);
    let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
-    call_scatter_add(
+    call_scatter(
        &device,
        command_buffer,
        &kernels,
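
A quick way to exercise the whole patch end to end is to push a gradient through the new op and compare it against the masked rule from the backprop section; a sketch, assuming candle's `Var` and `backward` APIs behave as in the crate's examples:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::new(&[[1f32, 2., 3.]], &dev)?;
    let ids = Tensor::new(&[[0u32]], &dev)?;
    let src = Tensor::new(&[[7f32]], &dev)?;
    // Forward: [7, 2, 3]; loss = sum of squares.
    let loss = x.scatter(&ids, &src, 1)?.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    let gx = grads.get(&x).expect("gradient for x");
    // d(loss)/dx is 2*x off the scattered slot and 0 on it.
    assert_eq!(gx.to_vec2::<f32>()?, [[0.0, 4.0, 6.0]]);
    Ok(())
}
```
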