Adding indexing.

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-11-10 01:58:51 +01:00
committed by Nicolas Patry
parent df6814f34e
commit f82bf2d915
4 changed files with 191 additions and 138 deletions

View File

@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
todo!()
}
fn index_select(
&self,
_ids: &Self,
_src_l: &Layout,
_ids_l: &Layout,
_dim: usize,
) -> Result<Self> {
todo!("Index select");
// let ids_shape = ids_l.shape();
// let left_size: usize = src_l.dims()[..dim].iter().product();
// let right_size: usize = src_l.dims()[dim + 1..].iter().product();
// let src_dim_size = src_l.dims()[dim];
// let ids_dim_size = ids_shape.elem_count();
// let dst_el = ids_shape.elem_count() * left_size * right_size;
// let dtype = self.dtype;
// let device = self.device();
// let buffer = device.new_buffer(dst_el, dtype);
// Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// })
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
assert!(src_l.is_contiguous());
assert!(ids_l.is_contiguous());
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let ids_el = ids_l.shape().elem_count();
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"),
};
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
&self.device.kernels,
name,
src_l.dims(),
ids_el,
dim,
&self.buffer,
&ids.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
command_buffer.commit();
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn index_add(