mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the scatter in place ops. (#2923)
* Add the scatter_set op. * Metal op. * Cuda version. * Merge the checks. * Add the actual ops.
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
//!
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -507,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -529,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||
@ -536,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let ids_dim_sz = ids_l.dims()[0];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
@ -544,7 +548,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
||||
barg!(builder, ids);
|
||||
barg!(builder, ids_dim_sz);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -557,7 +561,7 @@ impl Map2InPlace for Scatter<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -579,6 +583,10 @@ impl Map2InPlace for Scatter<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||
@ -586,13 +594,13 @@ impl Map2InPlace for Scatter<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -605,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -627,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
got: ids.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dst = match dst_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => dst.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let src = match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
@ -634,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
|
||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let src_dim_sz = src_l.dims()[dim];
|
||||
let dst_dim_sz = dst_shape.dims()[dim];
|
||||
let dst_dim_sz = dst_l.dims()[dim];
|
||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, ids);
|
||||
builder.arg(&src);
|
||||
builder.arg(dst);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
@ -1886,35 +1898,29 @@ impl BackendStorage for CudaStorage {
|
||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
fn scatter(
|
||||
&self,
|
||||
fn scatter_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
Scatter(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn scatter_add(
|
||||
&self,
|
||||
fn scatter_add_set(
|
||||
&mut self,
|
||||
l: &Layout,
|
||||
ids: &Self,
|
||||
ids_l: &Layout,
|
||||
src: &Self,
|
||||
src_l: &Layout,
|
||||
dim: usize,
|
||||
) -> Result<Self> {
|
||||
) -> Result<()> {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||
}
|
||||
fn index_add(
|
||||
&self,
|
||||
@ -1928,7 +1934,7 @@ impl BackendStorage for CudaStorage {
|
||||
let device = self.device().clone();
|
||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||
self.copy_strided_src(&mut acc, 0, l)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
|
||||
Ok(acc)
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/// Helper functions to plug cuda kernels in candle.
|
||||
use crate::{Layout, Result, Shape, WithDType};
|
||||
use crate::{Layout, Result, WithDType};
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||
|
||||
@ -96,7 +96,7 @@ pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
dst: &mut CudaSlice<T>,
|
||||
dst_shape: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &CudaSlice<T>,
|
||||
src_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
@ -105,19 +105,19 @@ pub trait Map2InPlace {
|
||||
fn map(
|
||||
&self,
|
||||
dst: &mut S,
|
||||
dst_s: &Shape,
|
||||
dst_l: &Layout,
|
||||
src: &S,
|
||||
src_l: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
match (dst, src) {
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user