Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
This commit is contained in:
Laurent Mazare
2025-04-25 21:46:58 +02:00
committed by GitHub
parent 3aeb9575c7
commit 3827685524
15 changed files with 429 additions and 19 deletions

View File

@ -552,6 +552,54 @@ impl Map2InPlace for IndexAdd<'_> {
}
}
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for Scatter<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
dst: &mut CudaSlice<T>,
dst_shape: &Shape,
src: &CudaSlice<T>,
src_l: &Layout,
dev: &CudaDevice,
) -> Result<()> {
let ids = &self.0;
let ids_l = &self.1;
let dim = self.2;
let (ids_o1, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let (name, (ids, _guard)) = match &ids.slice {
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
_ => Err(CudaError::UnexpectedDType {
msg: "scatter ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
};
let left_sz: usize = src_l.dims()[..dim].iter().product();
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_sz = src_l.dims()[dim];
let dst_dim_sz = dst_shape.dims()[dim];
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
let mut builder = func.builder();
barg!(builder, ids);
builder.arg(&src);
builder.arg(dst);
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
impl Map2InPlace for ScatterAdd<'_> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1838,6 +1886,21 @@ impl BackendStorage for CudaStorage {
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
Ok(Self { slice, device })
}
fn scatter(
&self,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
let device = self.device().clone();
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
self.copy_strided_src(&mut acc, 0, l)?;
Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
Ok(acc)
}
fn scatter_add(
&self,
l: &Layout,