mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the scatter in place ops. (#2923)
* Add the scatter_set op. * Metal op. * Cuda version. * Merge the checks. * Add the actual ops.
This commit is contained in:
@ -71,24 +71,27 @@ pub trait BackendStorage: Sized {
|
||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn scatter(
|
||||
&self,
|
||||
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self>;
|
||||
fn scatter_add(
|
||||
&self,
|
||||
) -> Result<()>;
|
||||
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self>;
|
||||
) -> Result<()>;
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||
fn index_add(
|
||||
&self,
|
||||
|
@ -7,7 +7,7 @@ use rayon::prelude::*;
|
||||
|
||||
mod utils;
|
||||
pub use utils::{
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
|
||||
};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
@ -591,12 +591,20 @@ impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: IntDType, M: ElemUpdate> Map2 for Scatter<'_, I, M> {
|
||||
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
||||
const OP: &'static str = "scatter";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
let mut dst = vec![T::zero(); dst_len];
|
||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
||||
fn f<T: WithDType>(
|
||||
&self,
|
||||
dst: &mut [T],
|
||||
dst_l: &Layout,
|
||||
src: &[T],
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &mut dst[o1..o2],
|
||||
};
|
||||
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
Some((o1, o2)) => &src[o1..o2],
|
||||
@ -604,7 +612,7 @@ impl<I: IntDType, M: ElemUpdate> Map2 for Scatter<'_, I, M> {
|
||||
|
||||
let dim = self.dim;
|
||||
let ids_dims = self.ids_l.dims();
|
||||
let dst_dims = l1.dims();
|
||||
let dst_dims = dst_l.dims();
|
||||
let dst_dim_len = dst_dims[dim];
|
||||
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||
|
||||
@ -638,7 +646,7 @@ impl<I: IntDType, M: ElemUpdate> Map2 for Scatter<'_, I, M> {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dst)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@ -2412,15 +2420,15 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
@ -2429,15 +2437,15 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
match ids {
|
||||
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||
|
@ -58,6 +58,30 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
|
||||
|
||||
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
|
||||
match (v1, v2) {
|
||||
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: v1.dtype(),
|
||||
rhs: v2.dtype(),
|
||||
op: Self::OP,
|
||||
}
|
||||
.bt())?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2U8 {
|
||||
const OP: &'static str;
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||
|
@ -2,7 +2,7 @@
|
||||
//!
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -507,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -529,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
@ -536,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let ids_dim_sz = ids_l.dims()[0];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
@ -544,7 +548,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
barg!(builder, ids);
|
||||
barg!(builder, ids_dim_sz);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -557,7 +561,7 @@ impl Map2InPlace for Scatter<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -579,6 +583,10 @@ impl Map2InPlace for Scatter<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
@ -586,13 +594,13 @@ impl Map2InPlace for Scatter<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -605,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -627,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
@ -634,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -1886,35 +1898,29 @@ impl BackendStorage for CudaStorage {
|
||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -1928,7 +1934,7 @@ impl BackendStorage for CudaStorage {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Helper functions to plug cuda kernels in candle.
|
||||
use crate::{Layout, Result, Shape, WithDType};
|
||||
use crate::{Layout, Result, WithDType};
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||
|
||||
@ -96,7 +96,7 @@ pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -105,19 +105,19 @@ pub trait Map2InPlace {
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
|
@ -128,27 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
|
@ -132,27 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &Layout,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
|
@ -1426,18 +1426,16 @@ impl BackendStorage for MetalStorage {
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
}
|
||||
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
@ -1458,6 +1456,7 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter(
|
||||
@ -1470,24 +1469,22 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
) -> Result<()> {
|
||||
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
@ -1508,6 +1505,7 @@ impl BackendStorage for MetalStorage {
|
||||
})?,
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||
candle_metal_kernels::call_scatter(
|
||||
@ -1520,10 +1518,10 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
src,
|
||||
ids,
|
||||
&acc.buffer,
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(acc)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
|
@ -628,60 +628,56 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scatter(
|
||||
&self,
|
||||
pub(crate) fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-set")?;
|
||||
self.same_device(source, "scatter-set")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn scatter_add(
|
||||
&self,
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
indexes: &Self,
|
||||
indexes_l: &Layout,
|
||||
source: &Self,
|
||||
source_l: &Layout,
|
||||
d: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
self.same_device(indexes, "scatter-add")?;
|
||||
self.same_device(source, "scatter-add")?;
|
||||
match (self, indexes, source) {
|
||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
||||
Ok(Self::Metal(storage))
|
||||
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn index_add(
|
||||
|
@ -1354,8 +1354,7 @@ impl Tensor {
|
||||
self.index_select(ids, 0)
|
||||
}
|
||||
|
||||
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter")?;
|
||||
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
@ -1386,8 +1385,19 @@ impl Tensor {
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let storage = self.storage().scatter(
|
||||
self.layout(),
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_set(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
@ -1400,40 +1410,33 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_set(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||
let source_dims = source.dims();
|
||||
let self_dims = self.dims();
|
||||
let mismatch = if source_dims.len() != self_dims.len() {
|
||||
true
|
||||
} else {
|
||||
let mut mismatch = false;
|
||||
for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
|
||||
if i != dim && d1 != d2 {
|
||||
mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
mismatch
|
||||
};
|
||||
if mismatch {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (self, src)",
|
||||
lhs: self.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if indexes.dims() != source.dims() {
|
||||
Err(Error::ShapeMismatchBinaryOp {
|
||||
op: "scatter-add (indexes, src)",
|
||||
lhs: indexes.shape().clone(),
|
||||
rhs: source.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let storage = self.storage().scatter_add(
|
||||
self.layout(),
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
let shape = self.shape();
|
||||
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let layout = Layout::contiguous(shape);
|
||||
storage.scatter_add(
|
||||
&layout,
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
@ -1446,6 +1449,23 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||
if self.same_storage(source) {
|
||||
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||
}
|
||||
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
|
||||
self.scatter_checks(indexes, source, dim)?;
|
||||
self.storage_mut().scatter_add(
|
||||
self.layout(),
|
||||
&indexes.storage(),
|
||||
indexes.layout(),
|
||||
&source.storage(),
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
||||
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
||||
|
@ -1087,6 +1087,18 @@ fn scatter(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
init.scatter_set(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
init.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -1457,7 +1457,7 @@ pub fn call_scatter(
|
||||
dim: usize,
|
||||
input: BufferOffset,
|
||||
ids: BufferOffset,
|
||||
output: &Buffer,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = src_shape[..dim].iter().product();
|
||||
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||
@ -1482,7 +1482,7 @@ pub fn call_scatter(
|
||||
dst_dim_size,
|
||||
&input,
|
||||
&ids,
|
||||
output
|
||||
&output
|
||||
)
|
||||
);
|
||||
|
||||
@ -1490,7 +1490,7 @@ pub fn call_scatter(
|
||||
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user