mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Adding gather op.
This commit is contained in:
@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage {
|
||||
crate::bail!("upsample_nearest2d metal")
|
||||
}
|
||||
|
||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||
crate::bail!("gather metal")
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
||||
Some(o12) => o12,
|
||||
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||
};
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let ids_el = ids_l.dims()[dim];
|
||||
let dst_el = ids_l.shape().elem_count();
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "gather_u32_f32",
|
||||
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||
(left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
|
||||
};
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_gather(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
|
Reference in New Issue
Block a user