mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Adding gather op.
This commit is contained in:
@ -826,8 +826,38 @@ impl BackendStorage for MetalStorage {
|
|||||||
crate::bail!("upsample_nearest2d metal")
|
crate::bail!("upsample_nearest2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
crate::bail!("gather metal")
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
||||||
|
};
|
||||||
|
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||||
|
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
|
let ids_el = ids_l.dims()[dim];
|
||||||
|
let dst_el = ids_l.shape().elem_count();
|
||||||
|
let dtype = self.dtype;
|
||||||
|
let device = self.device();
|
||||||
|
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||||
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U32, DType::F32) => "gather_u32_f32",
|
||||||
|
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||||
|
(left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
|
||||||
|
};
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_gather(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
src_l.dims(),
|
||||||
|
ids_el,
|
||||||
|
dim,
|
||||||
|
&self.buffer,
|
||||||
|
&ids.buffer,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
|
@ -1,6 +1,34 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
|
METAL_FUNC void index(
|
||||||
|
constant size_t &dst_size,
|
||||||
|
constant size_t &left_size,
|
||||||
|
constant size_t &src_dim_size,
|
||||||
|
constant size_t &right_size,
|
||||||
|
constant size_t &ids_size,
|
||||||
|
const device TYPENAME *input,
|
||||||
|
const device INDEX_TYPENAME *input_ids,
|
||||||
|
device TYPENAME *output,
|
||||||
|
uint tid [[ thread_position_in_grid ]]
|
||||||
|
) {
|
||||||
|
if (tid >= dst_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const size_t id_i = (tid / right_size) % ids_size;
|
||||||
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||||
|
const size_t right_rank_i = tid % right_size;
|
||||||
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
|
/*
|
||||||
|
// Force prevent out of bounds indexing
|
||||||
|
// since there doesn't seem to be a good way to force crash
|
||||||
|
// No need to check for zero we're only allowing unsized.
|
||||||
|
*/
|
||||||
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||||
|
output[tid] = input[src_i];
|
||||||
|
}
|
||||||
|
|
||||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &dst_size, \
|
constant size_t &dst_size, \
|
||||||
@ -11,22 +39,52 @@ kernel void NAME( \
|
|||||||
const device TYPENAME *input, \
|
const device TYPENAME *input, \
|
||||||
const device INDEX_TYPENAME *input_ids, \
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
device TYPENAME *output, \
|
device TYPENAME *output, \
|
||||||
uint gid [[ thread_position_in_grid ]] \
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
) { \
|
) { \
|
||||||
if (gid >= dst_size) { \
|
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||||
return; \
|
}
|
||||||
} \
|
|
||||||
const size_t id_i = (gid / right_size) % ids_size; \
|
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
const size_t right_rank_i = gid % right_size; \
|
METAL_FUNC void gather(
|
||||||
const size_t left_rank_i = gid / right_size / ids_size; \
|
constant size_t &dst_size,
|
||||||
/* \
|
constant size_t &left_size,
|
||||||
// Force prevent out of bounds indexing \
|
constant size_t &src_dim_size,
|
||||||
// since there doesn't seem to be a good way to force crash \
|
constant size_t &right_size,
|
||||||
// No need to check for zero we're only allowing unsized. \
|
constant size_t &ids_size,
|
||||||
*/ \
|
const device TYPENAME *input,
|
||||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
|
const device INDEX_TYPENAME *input_ids,
|
||||||
output[gid] = input[src_i]; \
|
device TYPENAME *output,
|
||||||
|
uint tid [[ thread_position_in_grid ]]
|
||||||
|
) {
|
||||||
|
if (tid >= dst_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||||
|
const size_t right_rank_i = tid % right_size;
|
||||||
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
|
/*
|
||||||
|
// Force prevent out of bounds indexing
|
||||||
|
// since there doesn't seem to be a good way to force crash
|
||||||
|
// No need to check for zero we're only allowing unsized.
|
||||||
|
*/
|
||||||
|
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||||
|
output[tid] = input[src_i];
|
||||||
|
}
|
||||||
|
|
||||||
|
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &dst_size, \
|
||||||
|
constant size_t &left_size, \
|
||||||
|
constant size_t &src_dim_size, \
|
||||||
|
constant size_t &right_size, \
|
||||||
|
constant size_t &ids_size, \
|
||||||
|
const device TYPENAME *input, \
|
||||||
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -76,6 +134,8 @@ kernel void FN_NAME( \
|
|||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint, float)
|
INDEX_OP(is_u32_f32, uint, float)
|
||||||
INDEX_OP(is_u32_f16, uint, half)
|
INDEX_OP(is_u32_f16, uint, half)
|
||||||
|
GATHER_OP(gather_u32_f32, uint, float)
|
||||||
|
GATHER_OP(gather_u32_f16, uint, half)
|
||||||
|
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
@ -1010,6 +1010,56 @@ pub fn call_index_select(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_gather(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: &'static str,
|
||||||
|
shape: &[usize],
|
||||||
|
ids_size: usize,
|
||||||
|
dim: usize,
|
||||||
|
input: &Buffer,
|
||||||
|
ids: &Buffer,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
|
let src_dim_size = shape[dim];
|
||||||
|
let dst_el = ids_size * left_size * right_size;
|
||||||
|
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
dst_el,
|
||||||
|
left_size,
|
||||||
|
src_dim_size,
|
||||||
|
right_size,
|
||||||
|
ids_size,
|
||||||
|
input,
|
||||||
|
ids,
|
||||||
|
output
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(ids, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Value {
|
pub enum Value {
|
||||||
USize(usize),
|
USize(usize),
|
||||||
|
Reference in New Issue
Block a user