Add support for strided index-select on Metal (#1909)

* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
This commit is contained in:
Thomas Santerre
2024-03-22 02:30:02 -04:00
committed by GitHub
parent 6708870e63
commit fee33b45c2
4 changed files with 129 additions and 23 deletions

View File

@ -1067,8 +1067,13 @@ pub fn call_index_select(
shape: &[usize],
ids_size: usize,
dim: usize,
contiguous: bool,
src_dims: &[usize],
src_strides: &[usize],
input: &Buffer,
src_offset: usize,
ids: &Buffer,
ids_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
@ -1090,8 +1095,11 @@ pub fn call_index_select(
src_dim_size,
right_size,
ids_size,
input,
ids,
contiguous,
src_dims,
src_strides,
(input, src_offset),
(ids, ids_offset),
output
)
);