Adding indexing.

Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Author: Nicolas Patry
Date: 2023-11-10 01:58:51 +01:00
Parent: 9a2784b8ab
Commit: 02c2ec2c71
4 changed files with 191 additions and 138 deletions


@@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
         todo!()
     }
 
-    fn index_select(
-        &self,
-        _ids: &Self,
-        _src_l: &Layout,
-        _ids_l: &Layout,
-        _dim: usize,
-    ) -> Result<Self> {
-        todo!("Index select");
-        // let ids_shape = ids_l.shape();
-        // let left_size: usize = src_l.dims()[..dim].iter().product();
-        // let right_size: usize = src_l.dims()[dim + 1..].iter().product();
-        // let src_dim_size = src_l.dims()[dim];
-        // let ids_dim_size = ids_shape.elem_count();
-        // let dst_el = ids_shape.elem_count() * left_size * right_size;
-        // let dtype = self.dtype;
-        // let device = self.device();
-        // let buffer = device.new_buffer(dst_el, dtype);
-        // Ok(Self {
-        //     buffer,
-        //     device: device.clone(),
-        //     dtype,
-        // })
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        assert!(src_l.is_contiguous());
+        assert!(ids_l.is_contiguous());
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let ids_el = ids_l.shape().elem_count();
+        let dst_el = ids_el * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let mut buffer = device.new_buffer(dst_el, dtype);
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "is_u32_f32",
+            (left, right) => todo!("index select metal {left:?} {right:?}"),
+        };
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_index_select(
+            &device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            ids_el,
+            dim,
+            &self.buffer,
+            &ids.buffer,
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn index_add(
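For intuition on the `left_size`/`right_size` factors: they treat the source as a row-major `[left, dim, right]` block. Below is a minimal CPU sketch of such a gather, assuming that conventional layout; the helper is illustrative only and not part of the commit, and note that the Metal kernel in the next file decomposes its flat output index in a different order.

    // CPU sketch of index_select over a row-major [left, dim, right] source:
    // for every (l, id) pair, copy the contiguous `right`-sized slice.
    fn index_select_ref(src: &[f32], ids: &[u32], left: usize, dim_size: usize, right: usize) -> Vec<f32> {
        let mut dst = Vec::with_capacity(left * ids.len() * right);
        for l in 0..left {
            for &id in ids {
                let start = (l * dim_size + id as usize) * right;
                dst.extend_from_slice(&src[start..start + right]);
            }
        }
        dst
    }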


@ -1,39 +1,36 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
kernel void is_u32_f32( # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
constant size_t &dst_size, kernel void NAME( \
constant size_t &left_size, constant size_t &dst_size, \
constant size_t &src_dim_size, constant size_t &left_size, \
constant size_t &right_size, constant size_t &src_dim_size, \
constant size_t &ids_size, constant size_t &right_size, \
constant size_t &ids_size, \
const device float *input, const device TYPENAME *input, \
const device uint *input_ids, const device INDEX_TYPENAME *input_ids, \
device float *output, device TYPENAME *output, \
uint gid [[ thread_position_in_grid ]] \
uint gid [[ thread_position_in_grid ]] ) { \
) { if (gid >= dst_size) { \
return; \
if (gid >= dst_size) { } \
return; const size_t id_i = gid / right_size / left_size; \
} const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid % left_size; \
const size_t id_i = gid / right_size / left_size; /* \
const size_t right_rank_i = gid % right_size; // Force prevent out of bounds indexing \
const size_t left_rank_i = gid % left_size; // since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
// Force prevent out of bounds indexing */ \
// since there doesn't seem to be a good way to force crash const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
// No need to check for zero we're only allowing unsized. const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1)); output[gid] = input[src_i]; \
const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i;
output[gid] = input[src_i];
} }
template <typename T, typename I> template <typename T, typename I>
void index_add( void index_add(
device I *ids [[buffer(0)]], device I *ids [[buffer(0)]],
@ -82,6 +79,7 @@ kernel void FN_NAME( \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ ) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
INDEX_OP(is_u32_f32, uint, float)
#if __METAL_VERSION__ >= 310 #if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16) IA_OP(bfloat, int64_t, ia_i64_bf16)
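As a sanity check on the macro's per-thread arithmetic, here is a line-for-line Rust transcription of one `INDEX_OP` thread for the `u32`/`f32` instantiation (illustrative only, not part of the commit): `gid` is split into an id coordinate plus left/right remainders, and the id is clamped so an out-of-range index reads the last valid slice instead of out-of-bounds memory.

    // Rust mirror of one INDEX_OP thread (is_u32_f32 instantiation).
    fn index_op_thread(
        gid: usize, src: &[f32], ids: &[u32],
        left_size: usize, src_dim_size: usize, right_size: usize,
    ) -> f32 {
        let id_i = gid / right_size / left_size;
        let right_rank_i = gid % right_size;
        let left_rank_i = gid % left_size;
        // Clamp instead of trapping: Metal has no clean way to abort a kernel.
        let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
        let src_i = (input_i * right_size + right_rank_i) * left_size + left_rank_i;
        src[src_i]
    }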


@@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
     Ok(())
 }
 
+pub fn call_index_select(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    ids_size: usize,
+    dim: usize,
+    input: &Buffer,
+    ids: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let left_size: usize = shape[..dim].iter().product();
+    let right_size: usize = shape[dim + 1..].iter().product();
+    let src_dim_size = shape[dim];
+    let dst_el = ids_size * left_size * right_size;
+    let func = kernels.load_function(device, Source::Indexing, name)?;
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(&func)
+        .unwrap();
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            left_size,
+            src_dim_size,
+            right_size,
+            ids_size,
+            input,
+            ids,
+            output
+        )
+    );
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
+    let grid_size = MTLSize {
+        width: (dst_el as u64 + width - 1) / width,
+        height: 1,
+        depth: 1,
+    };
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+    encoder.dispatch_thread_groups(grid_size, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
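The grid sizing here is a plain ceiling division: `width` threads per group, capped at the pipeline's `max_total_threads_per_threadgroup()`, and enough groups to cover all `dst_el` outputs (stray threads in the last group exit early via the kernel's `gid >= dst_size` guard). A small sketch of the arithmetic, illustrative only:

    // Ceil-div grid sizing: grid_width(1000, 256) == 4, since 4 * 256 >= 1000.
    fn grid_width(dst_el: u64, max_threads: u64) -> u64 {
        let width = max_threads.min(dst_el);
        (dst_el + width - 1) / width
    }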
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1003,61 +1060,32 @@ mod tests {
     dim: usize,
 ) -> Vec<T> {
     let device = Device::system_default().expect("no device found");
-
-    let options = CompileOptions::new();
-    let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
-    let left_size: usize = shape[..dim].iter().product();
-    let right_size: usize = shape[dim + 1..].iter().product();
-    let src_dim_size = shape[dim];
-    let dst_el = ids.len() * left_size * right_size;
-    let ids_size = ids.len();
-
-    let function = library.get_function("is_u32_f32", None).unwrap();
-    let pipeline = device
-        .new_compute_pipeline_state_with_function(&function)
-        .unwrap();
-
     let command_queue = device.new_command_queue();
     let command_buffer = command_queue.new_command_buffer();
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.set_compute_pipeline_state(&pipeline);
-
     let embeddings_buffer = new_buffer(&device, &embeddings);
     let ids_buffer = new_buffer(&device, &ids);
+
+    let left_size: usize = shape[..dim].iter().product();
+    let right_size: usize = shape[dim + 1..].iter().product();
+    let dst_el = ids.len() * left_size * right_size;
     let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
-
-    set_params!(
-        encoder,
-        (
-            dst_el,
-            left_size,
-            src_dim_size,
-            right_size,
-            ids_size,
-            &embeddings_buffer,
-            &ids_buffer,
-            &mut dst_buffer
-        )
-    );
-
-    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
-    let grid_size = MTLSize {
-        width: (dst_el as u64 + width - 1) / width,
-        height: 1,
-        depth: 1,
-    };
-    let thread_group_size = MTLSize {
-        width,
-        height: 1,
-        depth: 1,
-    };
-    println!("{width:?} - {:?}", grid_size);
-    encoder.dispatch_thread_groups(grid_size, thread_group_size);
-    encoder.end_encoding();
+
+    let kernels = Kernels::new();
+    call_index_select(
+        &device,
+        &command_buffer,
+        &kernels,
+        "is_u32_f32",
+        shape,
+        ids.len(),
+        dim,
+        &embeddings_buffer,
+        &ids_buffer,
+        &mut dst_buffer,
+    )
+    .unwrap();
 
     command_buffer.commit();
     command_buffer.wait_until_completed();
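After `wait_until_completed()` the test can read the result straight out of the shared-mode buffer. A hedged sketch of such a read-back using the `metal` crate's `Buffer::contents()` follows; the helper name is hypothetical, and the actual read-back code sits outside this hunk.

    // Copy `n` values of T back out of a shared MTLBuffer once the GPU is done.
    // Unsafe contract: the buffer must actually hold at least `n` values of T.
    fn read_to_vec<T: Clone>(buffer: &metal::Buffer, n: usize) -> Vec<T> {
        let ptr = buffer.contents() as *const T;
        assert!(!ptr.is_null());
        unsafe { std::slice::from_raw_parts(ptr, n) }.to_vec()
    }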


@@ -18,45 +18,55 @@ METAL_FUNC uint get_strided_index(
 constant int THREADGROUP_SIZE = 256;
 
-kernel void fast_sum_float(
-    constant size_t &src_numel,
-    constant size_t &el_to_sum_per_block,
-    device const float *src,
-    device float *dst,
-    uint id [[ thread_position_in_grid ]],
-    uint tid [[ thread_index_in_threadgroup ]],
-    uint dst_id [[ threadgroup_position_in_grid ]],
-    uint blockDim [[ threads_per_threadgroup ]]
-) {
-
-    threadgroup float shared_memory[THREADGROUP_SIZE];
-
-    shared_memory[tid] = 0;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
-    size_t idx = start_idx + tid;
-
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
-        shared_memory[tid] += src[idx];
-        idx += blockDim;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] += shared_memory[tid + s];
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    dst[dst_id] = shared_memory[0];
-}
+# define REDUCE(FN, NAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint blockDim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup float shared_memory[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = 0; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        */ \
+        TYPENAME x = shared_memory[tid]; \
+        TYPENAME y = src[idx]; \
+        shared_memory[tid] = FN; \
+        idx += blockDim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            TYPENAME x = shared_memory[tid]; \
+            TYPENAME y = shared_memory[tid + s]; \
+            shared_memory[tid] = FN; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    dst[dst_id] = shared_memory[0]; \
+} \
 
 kernel void softmax_float(
     constant size_t &src_numel,
@@ -122,3 +132,8 @@ kernel void softmax_float(
         idx += blockDim;
     }
 }
+
+REDUCE(x + y, fast_sum_float, float)
+REDUCE(x * y, fast_mul_float, float)
+REDUCE(max(x, y), fast_max_float, float)
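The halving loop is the standard threadgroup tree reduction: each round folds the top half of shared memory into the bottom half, so after log2(blockDim) rounds slot 0 holds `FN` applied across every element (this assumes `blockDim` is a power of two, as the kernel does). A sequential Rust simulation of one threadgroup, illustrative only:

    // Simulate one threadgroup's tree reduction; `f` stands in for FN
    // (x + y, x * y, max(x, y)). The barriers between rounds are implicit
    // in this sequential model. Assumes shared.len() is a power of two.
    fn tree_reduce(shared: &mut [f32], f: impl Fn(f32, f32) -> f32) -> f32 {
        let mut s = shared.len() / 2;
        while s > 0 {
            for tid in 0..s {
                shared[tid] = f(shared[tid], shared[tid + s]);
            }
            s >>= 1;
        }
        shared[0]
    }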