mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add support for strided index-select on Metal (#1909)
* initial implementation
* use correct index, but still not breaking like it should have...
* fix test
This commit is contained in:
@ -2,9 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::CallConvTranspose2dCfg;
|
||||
use candle_metal_kernels::Kernels;
|
||||
use candle_metal_kernels::{self, CallConvTranspose2dCfg};
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
@ -1348,12 +1347,8 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
if !(src_l.is_contiguous()
|
||||
&& src_l.start_offset() == 0
|
||||
&& ids_l.is_contiguous()
|
||||
&& ids_l.start_offset() == 0)
|
||||
{
|
||||
crate::bail!("Metal strided index_select not implemented");
|
||||
if !ids_l.is_contiguous() {
|
||||
crate::bail!("Metal index_select requires contiguous ids")
|
||||
}
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
@ -1364,6 +1359,8 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
(DType::U8, DType::F32) => "is_u8_f32",
|
||||
(DType::U8, DType::F16) => "is_u8_f16",
|
||||
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
@ -1382,8 +1379,13 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
ids_el,
|
||||
dim,
|
||||
src_l.is_contiguous(),
|
||||
src_l.dims(),
|
||||
src_l.stride(),
|
||||
&self.buffer,
|
||||
src_l.start_offset() * dtype.size_in_bytes(),
|
||||
&ids.buffer,
|
||||
ids_l.start_offset() * ids.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
Reference in New Issue
Block a user