mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add support for strided index-select on Metal (#1909)
* initial implementation * use correct index, but still not breaking like it should have... * fix test
This commit is contained in:
@ -1067,8 +1067,13 @@ pub fn call_index_select(
|
||||
shape: &[usize],
|
||||
ids_size: usize,
|
||||
dim: usize,
|
||||
contiguous: bool,
|
||||
src_dims: &[usize],
|
||||
src_strides: &[usize],
|
||||
input: &Buffer,
|
||||
src_offset: usize,
|
||||
ids: &Buffer,
|
||||
ids_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
@ -1090,8 +1095,11 @@ pub fn call_index_select(
|
||||
src_dim_size,
|
||||
right_size,
|
||||
ids_size,
|
||||
input,
|
||||
ids,
|
||||
contiguous,
|
||||
src_dims,
|
||||
src_strides,
|
||||
(input, src_offset),
|
||||
(ids, ids_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
Reference in New Issue
Block a user