Index add.

This commit is contained in:
Nicolas Patry
2023-12-18 10:46:01 +01:00
parent 6a3ca7da0c
commit 8bd3d6b94b
3 changed files with 151 additions and 63 deletions

View File

@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage {
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
l: &Layout,
ids: &Self,
ids_l: &Layout,
src: &Self,
src_l: &Layout,
dim: usize,
) -> Result<Self> {
crate::bail!("index_add metal")
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
self.copy_strided_src(&mut acc, 0, l)?;
let (ids_offset, _) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let src_offset = match src_l.contiguous_offsets() {
Some((o1, _)) => o1,
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "ia_u32_f32",
_ => Err(MetalError::UnexpectedDType {
msg: "index-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
};
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_index_add(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
l.dims(),
ids_l.dims(),
dim,
&src.buffer,
src_offset * src.dtype.size_in_bytes(),
&ids.buffer,
ids_offset * ids.dtype.size_in_bytes(),
&acc.buffer,
)
.map_err(MetalError::from)?;
Ok(acc)
}
fn matmul(
&self,