Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00.
Commit f82bf2d915 (parent df6814f34e)

Adding indexing.

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Committed by: Nicolas Patry
candle-core/src/metal_backend.rs

@@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
         todo!()
     }
 
-    fn index_select(
-        &self,
-        _ids: &Self,
-        _src_l: &Layout,
-        _ids_l: &Layout,
-        _dim: usize,
-    ) -> Result<Self> {
-        todo!("Index select");
-        // let ids_shape = ids_l.shape();
-        // let left_size: usize = src_l.dims()[..dim].iter().product();
-        // let right_size: usize = src_l.dims()[dim + 1..].iter().product();
-        // let src_dim_size = src_l.dims()[dim];
-        // let ids_dim_size = ids_shape.elem_count();
-        // let dst_el = ids_shape.elem_count() * left_size * right_size;
-        // let dtype = self.dtype;
-        // let device = self.device();
-        // let buffer = device.new_buffer(dst_el, dtype);
-        // Ok(Self {
-        //     buffer,
-        //     device: device.clone(),
-        //     dtype,
-        // })
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        assert!(src_l.is_contiguous());
+        assert!(ids_l.is_contiguous());
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let ids_el = ids_l.shape().elem_count();
+        let dst_el = ids_el * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let mut buffer = device.new_buffer(dst_el, dtype);
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "is_u32_f32",
+            (left, right) => todo!("index select metal {left:?} {right:?}"),
+        };
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_index_select(
+            &device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            ids_el,
+            dim,
+            &self.buffer,
+            &ids.buffer,
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn index_add(
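For reference, the size arithmetic above treats a contiguous tensor of shape dims as a [left_size, dims[dim], right_size] view and replaces the middle extent with the number of indices. A minimal standalone sketch of that decomposition (hypothetical helper, not candle API):

    // Sketch of the left/right decomposition used by index_select above.
    fn index_select_sizes(dims: &[usize], ids_len: usize, dim: usize) -> (usize, usize, usize) {
        let left_size: usize = dims[..dim].iter().product(); // extents before `dim`
        let right_size: usize = dims[dim + 1..].iter().product(); // extents after `dim`
        let dst_el = ids_len * left_size * right_size; // total output elements
        (left_size, right_size, dst_el)
    }

    fn main() {
        // Shape [4, 5, 6] with 3 indices along dim 1: the output has 4 * 3 * 6 = 72 elements.
        assert_eq!(index_select_sizes(&[4, 5, 6], 3, 1), (4, 6, 72));
    }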
candle-metal-kernels/src/indexing.metal

@@ -1,39 +1,36 @@
 #include <metal_stdlib>
 using namespace metal;
 
-kernel void is_u32_f32(
-    constant size_t &dst_size,
-    constant size_t &left_size,
-    constant size_t &src_dim_size,
-    constant size_t &right_size,
-    constant size_t &ids_size,
-    const device float *input,
-    const device uint *input_ids,
-    device float *output,
-    uint gid [[ thread_position_in_grid ]]
-) {
-
-    if (gid >= dst_size) {
-        return;
-    }
-
-    const size_t id_i = gid / right_size / left_size;
-    const size_t right_rank_i = gid % right_size;
-    const size_t left_rank_i = gid % left_size;
-
-    // Force prevent out of bounds indexing
-    // since there doesn't seem to be a good way to force crash
-    // No need to check for zero we're only allowing unsized.
-    const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1));
-    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i;
-
-    output[gid] = input[src_i];
-
-}
+# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &ids_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint gid [[ thread_position_in_grid ]] \
+) { \
+    if (gid >= dst_size) { \
+        return; \
+    } \
+    const size_t id_i = gid / right_size / left_size; \
+    const size_t right_rank_i = gid % right_size; \
+    const size_t left_rank_i = gid % left_size; \
+    /* \
+    // Force prevent out of bounds indexing \
+    // since there doesn't seem to be a good way to force crash \
+    // No need to check for zero we're only allowing unsized. \
+    */ \
+    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
+    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
+    output[gid] = input[src_i]; \
+}
 
 
 template <typename T, typename I>
 void index_add(
     device I *ids [[buffer(0)]],
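The per-thread body of INDEX_OP maps one output element gid back to a source element. A sequential Rust mirror of the same arithmetic, including the kernel's clamping of out-of-range ids (hypothetical reference function, one loop iteration per kernel thread):

    fn index_select_ref(src: &[f32], dims: &[usize], ids: &[u32], dim: usize) -> Vec<f32> {
        let left_size: usize = dims[..dim].iter().product();
        let right_size: usize = dims[dim + 1..].iter().product();
        let src_dim_size = dims[dim];
        (0..ids.len() * left_size * right_size)
            .map(|gid| {
                let id_i = gid / right_size / left_size;
                let right_rank_i = gid % right_size;
                let left_rank_i = gid % left_size;
                // Clamp rather than crash on an out-of-bounds id, as the kernel does.
                let input_i = ids[id_i].min(src_dim_size as u32 - 1) as usize;
                let src_i = (input_i * right_size + right_rank_i) * left_size + left_rank_i;
                src[src_i]
            })
            .collect()
    }

    fn main() {
        // Rows [[1, 2], [3, 4], [5, 6]]; gather rows 1, 0, 2 along dim 0.
        let out = index_select_ref(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &[1, 0, 2], 0);
        assert_eq!(out, vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0]);
    }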
@@ -82,6 +79,7 @@ kernel void FN_NAME( \
 ) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
 
+INDEX_OP(is_u32_f32, uint, float)
 
 #if __METAL_VERSION__ >= 310
 IA_OP(bfloat, int64_t, ia_i64_bf16)
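Each INDEX_OP instantiation pairs with one arm of the kernel-name match in the Metal backend above. A hypothetical sketch of how that dispatch generalizes as more dtype pairs are instantiated (the half-precision variant is illustrative only, not part of this commit):

    // One kernel name per (index dtype, value dtype) pair.
    #[derive(Debug, Clone, Copy)]
    enum DType {
        U32,
        F32,
        F16,
    }

    fn index_select_kernel_name(ids: DType, values: DType) -> Option<&'static str> {
        match (ids, values) {
            (DType::U32, DType::F32) => Some("is_u32_f32"),
            // Would also need INDEX_OP(is_u32_f16, uint, half) in indexing.metal.
            (DType::U32, DType::F16) => Some("is_u32_f16"),
            _ => None,
        }
    }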
candle-metal-kernels/src/lib.rs

@@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
     Ok(())
 }
 
+pub fn call_index_select(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    ids_size: usize,
+    dim: usize,
+    input: &Buffer,
+    ids: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let left_size: usize = shape[..dim].iter().product();
+    let right_size: usize = shape[dim + 1..].iter().product();
+    let src_dim_size = shape[dim];
+    let dst_el = ids_size * left_size * right_size;
+
+    let func = kernels.load_function(device, Source::Indexing, name)?;
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(&func)
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            left_size,
+            src_dim_size,
+            right_size,
+            ids_size,
+            input,
+            ids,
+            output
+        )
+    );
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
+    let grid_size = MTLSize {
+        width: (dst_el as u64 + width - 1) / width,
+        height: 1,
+        depth: 1,
+    };
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+    encoder.dispatch_thread_groups(grid_size, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
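The launch geometry at the end of call_index_select is one thread per output element, capped at the pipeline's threadgroup limit. A standalone sketch of that arithmetic (hypothetical helper; assumes dst_el > 0):

    fn launch_dims(dst_el: u64, max_threads_per_group: u64) -> (u64, u64) {
        let width = max_threads_per_group.min(dst_el); // threads per threadgroup
        let groups = (dst_el + width - 1) / width; // ceil(dst_el / width)
        (groups, width)
    }

    fn main() {
        // 10_000 outputs with a 1024-thread limit: 10 threadgroups of 1024 threads;
        // the 240 trailing threads are masked off by the kernel's `gid >= dst_size` guard.
        assert_eq!(launch_dims(10_000, 1024), (10, 1024));
    }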
@@ -1003,61 +1060,32 @@ mod tests {
         dim: usize,
     ) -> Vec<T> {
         let device = Device::system_default().expect("no device found");
-        let options = CompileOptions::new();
-        let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
-        let left_size: usize = shape[..dim].iter().product();
-        let right_size: usize = shape[dim + 1..].iter().product();
-        let src_dim_size = shape[dim];
-        let dst_el = ids.len() * left_size * right_size;
-        let ids_size = ids.len();
-
-        let function = library.get_function("is_u32_f32", None).unwrap();
-        let pipeline = device
-            .new_compute_pipeline_state_with_function(&function)
-            .unwrap();
-
         let command_queue = device.new_command_queue();
         let command_buffer = command_queue.new_command_buffer();
-        let encoder = command_buffer.new_compute_command_encoder();
-
-        encoder.set_compute_pipeline_state(&pipeline);
-
         let embeddings_buffer = new_buffer(&device, &embeddings);
         let ids_buffer = new_buffer(&device, &ids);
+
+        let left_size: usize = shape[..dim].iter().product();
+        let right_size: usize = shape[dim + 1..].iter().product();
+        let dst_el = ids.len() * left_size * right_size;
         let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
-        set_params!(
-            encoder,
-            (
-                dst_el,
-                left_size,
-                src_dim_size,
-                right_size,
-                ids_size,
-                &embeddings_buffer,
-                &ids_buffer,
-                &mut dst_buffer
-            )
-        );
-
-        let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
-        let grid_size = MTLSize {
-            width: (dst_el as u64 + width - 1) / width,
-            height: 1,
-            depth: 1,
-        };
-
-        let thread_group_size = MTLSize {
-            width,
-            height: 1,
-            depth: 1,
-        };
-
-        println!("{width:?} - {:?}", grid_size);
-
-        encoder.dispatch_thread_groups(grid_size, thread_group_size);
-        encoder.end_encoding();
+        let kernels = Kernels::new();
+        call_index_select(
+            &device,
+            &command_buffer,
+            &kernels,
+            "is_u32_f32",
+            shape,
+            ids.len(),
+            dim,
+            &embeddings_buffer,
+            &ids_buffer,
+            &mut dst_buffer,
+        )
+        .unwrap();
+
         command_buffer.commit();
         command_buffer.wait_until_completed();
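A concrete case this helper can exercise (hypothetical values, with run_index_select as a stand-in name for the test helper above; with dim = 0 the kernel's index arithmetic reduces to a plain row gather):

    // embeddings of shape [4, 2], row-major: [[1, 2], [3, 4], [5, 6], [7, 8]].
    let embeddings = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let ids = vec![0u32, 3, 1]; // gather rows 0, 3, 1
    let result = run_index_select(&embeddings, &ids, &[4, 2], 0);
    assert_eq!(result, vec![1.0, 2.0, 7.0, 8.0, 3.0, 4.0]);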
candle-metal-kernels/src/reduce.metal

@@ -18,45 +18,55 @@ METAL_FUNC uint get_strided_index(
 
 constant int THREADGROUP_SIZE = 256;
 
-kernel void fast_sum_float(
-    constant size_t &src_numel,
-    constant size_t &el_to_sum_per_block,
-    device const float *src,
-    device float *dst,
-    uint id [[ thread_position_in_grid ]],
-    uint tid [[ thread_index_in_threadgroup ]],
-    uint dst_id [[ threadgroup_position_in_grid ]],
-    uint blockDim [[ threads_per_threadgroup ]]
-) {
-
-    threadgroup float shared_memory[THREADGROUP_SIZE];
-
-    shared_memory[tid] = 0;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
-    size_t idx = start_idx + tid;
-
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
-        shared_memory[tid] += src[idx];
-        idx += blockDim;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] += shared_memory[tid + s];
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    dst[dst_id] = shared_memory[0];
-}
+# define REDUCE(FN, NAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint blockDim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = 0; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        */ \
+        TYPENAME x = shared_memory[tid]; \
+        TYPENAME y = src[idx]; \
+        shared_memory[tid] = FN; \
+        idx += blockDim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            TYPENAME x = shared_memory[tid]; \
+            TYPENAME y = shared_memory[tid + s]; \
+            shared_memory[tid] = FN; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    dst[dst_id] = shared_memory[0]; \
+} \
 
 kernel void softmax_float(
     constant size_t &src_numel,
@@ -122,3 +132,8 @@ kernel void softmax_float(
         idx += blockDim;
     }
 }
+
+
+REDUCE(x + y, fast_sum_float, float)
+REDUCE(x * y, fast_mul_float, float)
+REDUCE(max(x, y), fast_max_float, float)
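When FN is associative and commutative, the two phases of REDUCE (a grid-stride accumulation followed by a shared-memory tree reduction) compute a plain fold over the block's elements. A sequential Rust sketch of that semantics (hypothetical names); note that the macro seeds shared memory with 0, the identity for x + y, whereas the product and max instantiations would want 1 or the type's minimum instead:

    // Sequential equivalent of one REDUCE threadgroup: fold `src` with the binary op.
    fn block_reduce(src: &[f32], op: impl Fn(f32, f32) -> f32, init: f32) -> f32 {
        src.iter().copied().fold(init, |x, y| op(x, y))
    }

    fn main() {
        let src = [1.0f32, 2.0, 3.0, 4.0];
        assert_eq!(block_reduce(&src, |x, y| x + y, 0.0), 10.0); // fast_sum_float
        assert_eq!(block_reduce(&src, |x, y| x.max(y), f32::MIN), 4.0); // like fast_max_float
    }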