Adding gather op.

This commit is contained in:
Nicolas Patry
2023-12-17 23:34:12 +01:00
parent e4b0cc59f5
commit 586b6f6fff
3 changed files with 157 additions and 17 deletions

View File

@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest2d metal")
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
crate::bail!("gather metal")
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
Some(o12) => o12,
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
};
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.dims()[dim];
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
(left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_gather(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
&self.buffer,
&ids.buffer,
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(buffer, device.clone(), dtype))
}
fn scatter_add(