mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Adding indexing.
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
This commit is contained in:

committed by
Nicolas Patry

parent
df6814f34e
commit
f82bf2d915
@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn index_select(
|
||||
&self,
|
||||
_ids: &Self,
|
||||
_src_l: &Layout,
|
||||
_ids_l: &Layout,
|
||||
_dim: usize,
|
||||
) -> Result<Self> {
|
||||
todo!("Index select");
|
||||
// let ids_shape = ids_l.shape();
|
||||
// let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
// let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
// let src_dim_size = src_l.dims()[dim];
|
||||
// let ids_dim_size = ids_shape.elem_count();
|
||||
// let dst_el = ids_shape.elem_count() * left_size * right_size;
|
||||
// let dtype = self.dtype;
|
||||
// let device = self.device();
|
||||
// let buffer = device.new_buffer(dst_el, dtype);
|
||||
// Ok(Self {
|
||||
// buffer,
|
||||
// device: device.clone(),
|
||||
// dtype,
|
||||
// })
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
assert!(src_l.is_contiguous());
|
||||
assert!(ids_l.is_contiguous());
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let ids_el = ids_l.shape().elem_count();
|
||||
let dst_el = ids_el * left_size * right_size;
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||
};
|
||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||
candle_metal_kernels::call_index_select(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
|
Reference in New Issue
Block a user