From a2e925462ce61cfaf877b69d769b995df4830a64 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 26 Apr 2025 07:36:49 +0200 Subject: [PATCH] Add the scatter in place ops. (#2923) * Add the scatter_set op. * Metal op. * Cuda version. * Merge the checks. * Add the actual ops. --- candle-core/src/backend.rs | 15 +++-- candle-core/src/cpu_backend/mod.rs | 36 ++++++---- candle-core/src/cpu_backend/utils.rs | 24 +++++++ candle-core/src/cuda_backend/mod.rs | 56 +++++++++------- candle-core/src/cuda_backend/utils.rs | 20 +++--- candle-core/src/dummy_cuda_backend.rs | 12 ++-- candle-core/src/dummy_metal_backend.rs | 12 ++-- candle-core/src/metal_backend/mod.rs | 30 ++++----- candle-core/src/storage.rs | 34 +++++----- candle-core/src/tensor.rs | 92 ++++++++++++++++---------- candle-core/tests/tensor_tests.rs | 12 ++++ candle-metal-kernels/src/lib.rs | 6 +- 12 files changed, 208 insertions(+), 141 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index f3655065..a85f8d36 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -71,24 +71,27 @@ pub trait BackendStorage: Sized { fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result; - fn scatter( - &self, + + fn scatter_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result; - fn scatter_add( - &self, + ) -> Result<()>; + + fn scatter_add_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result; + ) -> Result<()>; + fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result; fn index_add( &self, diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index c9edeb5b..347710de 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -7,7 +7,7 @@ use rayon::prelude::*; mod utils; pub use utils::{ - binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8, }; const USE_IM2COL_CONV1D: bool = true; @@ -591,12 +591,20 @@ impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> { } } -impl Map2 for Scatter<'_, I, M> { +impl Map2InPlace for Scatter<'_, I, M> { const OP: &'static str = "scatter"; - fn f(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result> { - let dst_len = l1.shape().elem_count(); - let mut dst = vec![T::zero(); dst_len]; - copy_strided_src_(v1, &mut dst, 0, l1); + fn f( + &self, + dst: &mut [T], + dst_l: &Layout, + src: &[T], + src_l: &Layout, + ) -> Result<()> { + let dst = match dst_l.contiguous_offsets() { + None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?, + Some((o1, o2)) => &mut dst[o1..o2], + }; + let src = match src_l.contiguous_offsets() { None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?, Some((o1, o2)) => &src[o1..o2], @@ -604,7 +612,7 @@ impl Map2 for Scatter<'_, I, M> { let dim = self.dim; let ids_dims = self.ids_l.dims(); - let dst_dims = l1.dims(); + let dst_dims = dst_l.dims(); let dst_dim_len = dst_dims[dim]; let dst_right_len: usize = dst_dims[dim + 1..].iter().product(); @@ -638,7 +646,7 @@ impl Map2 for Scatter<'_, I, M> { } } - Ok(dst) + Ok(()) } } @@ -2412,15 +2420,15 @@ impl BackendStorage for CpuStorage { } } - fn scatter( - &self, + fn scatter_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { + ) -> Result<()> { match ids { Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l), @@ -2429,15 +2437,15 @@ impl BackendStorage for CpuStorage { } } - fn scatter_add( - &self, + fn scatter_add_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { + ) -> Result<()> { match ids { Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l), diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4..c404c3ad 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -58,6 +58,30 @@ pub trait Map2 { } } +pub trait Map2InPlace { + const OP: &'static str; + fn f(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>; + + fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?, + (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?, + (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?, + (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?, + (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?, + (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?, + (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?, + (v1, v2) => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt())?, + }; + Ok(()) + } +} + pub trait Map2U8 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index c36339b0..95987ba0 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2,7 +2,7 @@ //! use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; -use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType}; +use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; @@ -507,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> { fn f( &self, dst: &mut CudaSlice, - dst_shape: &Shape, + dst_l: &Layout, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, @@ -529,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> { got: ids.dtype(), })?, }; + let dst = match dst_l.contiguous_offsets() { + Some((o1, o2)) => dst.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, + }; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, @@ -536,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> { let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; - let dst_dim_sz = dst_shape.dims()[dim]; + let dst_dim_sz = dst_l.dims()[dim]; let ids_dim_sz = ids_l.dims()[0]; let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); let func = dev.get_or_load_func(&kernel_name::(name), &kernels::INDEXING)?; @@ -544,7 +548,7 @@ impl Map2InPlace for IndexAdd<'_> { barg!(builder, ids); barg!(builder, ids_dim_sz); builder.arg(&src); - builder.arg(dst); + builder.arg(&dst); barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; @@ -557,7 +561,7 @@ impl Map2InPlace for Scatter<'_> { fn f( &self, dst: &mut CudaSlice, - dst_shape: &Shape, + dst_l: &Layout, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, @@ -579,6 +583,10 @@ impl Map2InPlace for Scatter<'_> { got: ids.dtype(), })?, }; + let dst = match dst_l.contiguous_offsets() { + Some((o1, o2)) => dst.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?, + }; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?, @@ -586,13 +594,13 @@ impl Map2InPlace for Scatter<'_> { let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; - let dst_dim_sz = dst_shape.dims()[dim]; + let dst_dim_sz = dst_l.dims()[dim]; let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); let func = dev.get_or_load_func(&kernel_name::(name), &kernels::INDEXING)?; let mut builder = func.builder(); barg!(builder, ids); builder.arg(&src); - builder.arg(dst); + builder.arg(&dst); barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; @@ -605,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> { fn f( &self, dst: &mut CudaSlice, - dst_shape: &Shape, + dst_l: &Layout, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, @@ -627,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> { got: ids.dtype(), })?, }; + let dst = match dst_l.contiguous_offsets() { + Some((o1, o2)) => dst.slice(o1..o2), + None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, + }; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, @@ -634,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> { let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; - let dst_dim_sz = dst_shape.dims()[dim]; + let dst_dim_sz = dst_l.dims()[dim]; let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); let func = dev.get_or_load_func(&kernel_name::(name), &kernels::INDEXING)?; let mut builder = func.builder(); barg!(builder, ids); builder.arg(&src); - builder.arg(dst); + builder.arg(&dst); barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz); // SAFETY: ffi. unsafe { builder.launch(cfg) }.w()?; @@ -1886,35 +1898,29 @@ impl BackendStorage for CudaStorage { let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?; Ok(Self { slice, device }) } - fn scatter( - &self, + fn scatter_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { + ) -> Result<()> { let device = self.device().clone(); - let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; - self.copy_strided_src(&mut acc, 0, l)?; - Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; - Ok(acc) + Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device) } - fn scatter_add( - &self, + fn scatter_add_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { + ) -> Result<()> { let device = self.device().clone(); - let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; - self.copy_strided_src(&mut acc, 0, l)?; - ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; - Ok(acc) + ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device) } fn index_add( &self, @@ -1928,7 +1934,7 @@ impl BackendStorage for CudaStorage { let device = self.device().clone(); let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; - IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; + IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?; Ok(acc) } diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727..0a81f0ac 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -1,5 +1,5 @@ /// Helper functions to plug cuda kernels in candle. -use crate::{Layout, Result, Shape, WithDType}; +use crate::{Layout, Result, WithDType}; pub use cudarc; use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits}; @@ -96,7 +96,7 @@ pub trait Map2InPlace { fn f( &self, dst: &mut CudaSlice, - dst_shape: &Shape, + dst_l: &Layout, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, @@ -105,19 +105,19 @@ pub trait Map2InPlace { fn map( &self, dst: &mut S, - dst_s: &Shape, + dst_l: &Layout, src: &S, src_l: &Layout, d: &CudaDevice, ) -> Result<()> { match (dst, src) { - (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), - (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), - (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d), + (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d), + (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d), + (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d), + (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d), + (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d), + (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 0d635d75..32909935 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -128,27 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn scatter( - &self, + fn scatter_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result { + ) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } - fn scatter_add( - &self, + fn scatter_add_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result { + ) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 80493024..de43f243 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -132,27 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn scatter( - &self, + fn scatter_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result { + ) -> Result<()> { Err(Error::NotCompiledWithMetalSupport) } - fn scatter_add( - &self, + fn scatter_add_set( + &mut self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout, _: usize, - ) -> Result { + ) -> Result<()> { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index c609ebd7..cdbeb65d 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1426,18 +1426,16 @@ impl BackendStorage for MetalStorage { Ok(Self::new(buffer, device.clone(), dst_el, dtype)) } - fn scatter( - &self, + fn scatter_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { - let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; - self.copy_strided_src(&mut acc, 0, l)?; - if !ids_l.is_contiguous() || !src_l.is_contiguous() { + ) -> Result<()> { + if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() { return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt()); }; let name = match (ids.dtype, self.dtype) { @@ -1458,6 +1456,7 @@ impl BackendStorage for MetalStorage { })?, }; let command_buffer = self.device.command_buffer()?; + let dst = buffer_o(&self.buffer, l, self.dtype); let src = buffer_o(&src.buffer, src_l, src.dtype); let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_scatter( @@ -1470,24 +1469,22 @@ impl BackendStorage for MetalStorage { dim, src, ids, - &acc.buffer, + dst, ) .map_err(MetalError::from)?; - Ok(acc) + Ok(()) } - fn scatter_add( - &self, + fn scatter_add_set( + &mut self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, - ) -> Result { - let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?; - self.copy_strided_src(&mut acc, 0, l)?; - if !ids_l.is_contiguous() || !src_l.is_contiguous() { + ) -> Result<()> { + if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() { return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { @@ -1508,6 +1505,7 @@ impl BackendStorage for MetalStorage { })?, }; let command_buffer = self.device.command_buffer()?; + let dst = buffer_o(&self.buffer, l, self.dtype); let src = buffer_o(&src.buffer, src_l, src.dtype); let ids = buffer_o(&ids.buffer, ids_l, ids.dtype); candle_metal_kernels::call_scatter( @@ -1520,10 +1518,10 @@ impl BackendStorage for MetalStorage { dim, src, ids, - &acc.buffer, + dst, ) .map_err(MetalError::from)?; - Ok(acc) + Ok(()) } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 4257481b..32af5824 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -628,60 +628,56 @@ impl Storage { } } - pub(crate) fn scatter( - &self, + pub(crate) fn scatter_set( + &mut self, l: &Layout, indexes: &Self, indexes_l: &Layout, source: &Self, source_l: &Layout, d: usize, - ) -> Result { - self.same_device(indexes, "scatter-add")?; - self.same_device(source, "scatter-add")?; + ) -> Result<()> { + self.same_device(indexes, "scatter-set")?; + self.same_device(source, "scatter-set")?; match (self, indexes, source) { (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => { - let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Cpu(storage)) + s.scatter_set(l, indexes, indexes_l, source, source_l, d)?; } (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => { - let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Cuda(storage)) + s.scatter_set(l, indexes, indexes_l, source, source_l, d)?; } (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { - let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Metal(storage)) + s.scatter_set(l, indexes, indexes_l, source, source_l, d)?; } _ => unreachable!(), } + Ok(()) } pub(crate) fn scatter_add( - &self, + &mut self, l: &Layout, indexes: &Self, indexes_l: &Layout, source: &Self, source_l: &Layout, d: usize, - ) -> Result { + ) -> Result<()> { self.same_device(indexes, "scatter-add")?; self.same_device(source, "scatter-add")?; match (self, indexes, source) { (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => { - let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Cpu(storage)) + s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?; } (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => { - let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Cuda(storage)) + s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?; } (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => { - let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?; - Ok(Self::Metal(storage)) + s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?; } _ => unreachable!(), } + Ok(()) } pub(crate) fn index_add( diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 26e2e3b5..fdbd2e45 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1354,8 +1354,7 @@ impl Tensor { self.index_select(ids, 0) } - pub fn scatter(&self, indexes: &Self, source: &Self, dim: D) -> Result { - let dim = dim.to_index(self.shape(), "scatter")?; + fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> { let source_dims = source.dims(); let self_dims = self.dims(); let mismatch = if source_dims.len() != self_dims.len() { @@ -1386,8 +1385,19 @@ impl Tensor { } .bt())? } - let storage = self.storage().scatter( - self.layout(), + Ok(()) + } + + pub fn scatter(&self, indexes: &Self, source: &Self, dim: D) -> Result { + let dim = dim.to_index(self.shape(), "scatter")?; + self.scatter_checks(indexes, source, dim)?; + let shape = self.shape(); + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let layout = Layout::contiguous(shape); + storage.scatter_set( + &layout, &indexes.storage(), indexes.layout(), &source.storage(), @@ -1400,40 +1410,33 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } + pub fn scatter_set(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> { + if self.same_storage(source) { + crate::bail!("cannot use slice_set when self and src share their storage") + } + let dim = dim.to_index(self.shape(), "scatter-set")?; + self.scatter_checks(indexes, source, dim)?; + self.storage_mut().scatter_set( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + Ok(()) + } + pub fn scatter_add(&self, indexes: &Self, source: &Self, dim: D) -> Result { let dim = dim.to_index(self.shape(), "scatter-add")?; - let source_dims = source.dims(); - let self_dims = self.dims(); - let mismatch = if source_dims.len() != self_dims.len() { - true - } else { - let mut mismatch = false; - for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { - if i != dim && d1 != d2 { - mismatch = true; - break; - } - } - mismatch - }; - if mismatch { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (self, src)", - lhs: self.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - if indexes.dims() != source.dims() { - Err(Error::ShapeMismatchBinaryOp { - op: "scatter-add (indexes, src)", - lhs: indexes.shape().clone(), - rhs: source.shape().clone(), - } - .bt())? - } - let storage = self.storage().scatter_add( - self.layout(), + self.scatter_checks(indexes, source, dim)?; + let shape = self.shape(); + let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? }; + self.storage() + .copy_strided_src(&mut storage, 0, self.layout())?; + let layout = Layout::contiguous(shape); + storage.scatter_add( + &layout, &indexes.storage(), indexes.layout(), &source.storage(), @@ -1446,6 +1449,23 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } + pub fn scatter_add_set(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> { + if self.same_storage(source) { + crate::bail!("cannot use slice_set when self and src share their storage") + } + let dim = dim.to_index(self.shape(), "scatter-add-set")?; + self.scatter_checks(indexes, source, dim)?; + self.storage_mut().scatter_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + Ok(()) + } + /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension. pub fn slice_scatter(&self, src: &Self, dim: D, start: usize) -> Result { let dim = dim.to_index(self.shape(), "slice-scatter")?; diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 7e2d41ba..8767bc8c 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1087,6 +1087,18 @@ fn scatter(device: &Device) -> Result<()> { [1.0, 1.0, 1.0] ] ); + init.scatter_set(&ids, &t, 0)?; + assert_eq!( + init.to_vec2::()?, + &[ + [0.0, 10.0, 5.0], + [1.0, 1.0, 8.0], + [9.0, 1.0, 2.0], + [6.0, 7.0, 1.0], + [1.0, 4.0, 11.0], + [1.0, 1.0, 1.0] + ] + ); Ok(()) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 9f689a07..de1b1053 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1457,7 +1457,7 @@ pub fn call_scatter( dim: usize, input: BufferOffset, ids: BufferOffset, - output: &Buffer, + output: BufferOffset, ) -> Result<(), MetalKernelError> { let left_size: usize = src_shape[..dim].iter().product(); let right_size: usize = src_shape[dim + 1..].iter().product(); @@ -1482,7 +1482,7 @@ pub fn call_scatter( dst_dim_size, &input, &ids, - output + &output ) ); @@ -1490,7 +1490,7 @@ pub fn call_scatter( encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read); encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) }