Add support for strided index-select on Metal (#1909)

* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
This commit is contained in:
Thomas Santerre
2024-03-22 02:30:02 -04:00
committed by GitHub
parent 6708870e63
commit fee33b45c2
4 changed files with 129 additions and 23 deletions

View File

@@ -2,9 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels::CallConvTranspose2dCfg;
use candle_metal_kernels::Kernels;
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -1348,12 +1347,8 @@ impl BackendStorage for MetalStorage {
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
if !(src_l.is_contiguous()
&& src_l.start_offset() == 0
&& ids_l.is_contiguous()
&& ids_l.start_offset() == 0)
{
crate::bail!("Metal strided index_select not implemented");
if !ids_l.is_contiguous() {
crate::bail!("Metal index_select requires contiguous ids")
}
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
@@ -1364,6 +1359,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
@@ -1382,8 +1379,13 @@ impl BackendStorage for MetalStorage {
src_l.dims(),
ids_el,
dim,
src_l.is_contiguous(),
src_l.dims(),
src_l.stride(),
&self.buffer,
src_l.start_offset() * dtype.size_in_bytes(),
&ids.buffer,
ids_l.start_offset() * ids.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;