diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 68a96672..ed592240 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage {
         todo!()
     }
 
-    fn index_select(
-        &self,
-        _ids: &Self,
-        _src_l: &Layout,
-        _ids_l: &Layout,
-        _dim: usize,
-    ) -> Result<Self> {
-        todo!("Index select");
-        // let ids_shape = ids_l.shape();
-        // let left_size: usize = src_l.dims()[..dim].iter().product();
-        // let right_size: usize = src_l.dims()[dim + 1..].iter().product();
-        // let src_dim_size = src_l.dims()[dim];
-        // let ids_dim_size = ids_shape.elem_count();
-        // let dst_el = ids_shape.elem_count() * left_size * right_size;
-        // let dtype = self.dtype;
-        // let device = self.device();
-        // let buffer = device.new_buffer(dst_el, dtype);
-        // Ok(Self {
-        //     buffer,
-        //     device: device.clone(),
-        //     dtype,
-        // })
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        assert!(src_l.is_contiguous());
+        assert!(ids_l.is_contiguous());
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let ids_el = ids_l.shape().elem_count();
+        let dst_el = ids_el * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let mut buffer = device.new_buffer(dst_el, dtype);
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U32, DType::F32) => "is_u32_f32",
+            (left, right) => todo!("index select metal {left:?} {right:?}"),
+        };
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_index_select(
+            &device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            ids_el,
+            dim,
+            &self.buffer,
+            &ids.buffer,
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }
 
     fn index_add(
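
For context, the semantics this backend method has to reproduce: with a contiguous source viewed as [left dims, dims[dim], right dims], each picked index contributes one contiguous run of right_size elements per left slice. A minimal CPU sketch of that decomposition (plain Rust; index_select_ref is a hypothetical helper for illustration, not part of this patch):

// Reference semantics of index_select on a contiguous source.
// Illustrative only; mirrors the left/right decomposition above.
fn index_select_ref(src: &[f32], dims: &[usize], ids: &[u32], dim: usize) -> Vec<f32> {
    let left_size: usize = dims[..dim].iter().product();
    let right_size: usize = dims[dim + 1..].iter().product();
    let src_dim_size = dims[dim];
    let mut dst = Vec::with_capacity(left_size * ids.len() * right_size);
    for l in 0..left_size {
        for &id in ids {
            // Each picked index copies one contiguous run of right_size values.
            let start = (l * src_dim_size + id as usize) * right_size;
            dst.extend_from_slice(&src[start..start + right_size]);
        }
    }
    dst
}

For dims [5, 2], ids [0, 4, 2] and dim 0 this yields rows 0, 4 and 2 of the source, which is exactly what the Metal test below asserts.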
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index eefaef34..c077cc48 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -1,39 +1,36 @@
 #include <metal_stdlib>
 using namespace metal;
 
-kernel void is_u32_f32(
-    constant size_t &dst_size,
-    constant size_t &left_size,
-    constant size_t &src_dim_size,
-    constant size_t &right_size,
-    constant size_t &ids_size,
-
-    const device float *input,
-    const device uint *input_ids,
-    device float *output,
-
-    uint gid [[ thread_position_in_grid ]]
-) {
-
-    if (gid >= dst_size) {
-        return;
-    }
-
-    const size_t id_i = gid / right_size / left_size;
-    const size_t right_rank_i = gid % right_size;
-    const size_t left_rank_i = gid % left_size;
-
-    // Force prevent out of bounds indexing
-    // since there doesn't seem to be a good way to force crash
-    // No need to check for zero we're only allowing unsized.
-    const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1));
-    const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i;
-
-    output[gid] = input[src_i];
-
+# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+    constant size_t &dst_size, \
+    constant size_t &left_size, \
+    constant size_t &src_dim_size, \
+    constant size_t &right_size, \
+    constant size_t &ids_size, \
+    const device TYPENAME *input, \
+    const device INDEX_TYPENAME *input_ids, \
+    device TYPENAME *output, \
+    uint gid [[ thread_position_in_grid ]] \
+) { \
+    if (gid >= dst_size) { \
+        return; \
+    } \
+    const size_t left_rank_i = gid / right_size / ids_size; \
+    const size_t id_i = (gid / right_size) % ids_size; \
+    const size_t right_rank_i = gid % right_size; \
+    /* \
+    // Clamp the picked index to prevent out-of-bounds access, since \
+    // there is no good way to abort a Metal kernel on error. \
+    // No lower-bound check is needed: the index type is unsigned. \
+    */ \
+    const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
+    const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; \
+    output[gid] = input[src_i]; \
 }
+
 
 template <typename T, typename I>
 void index_add(
     device I *ids [[buffer(0)]],
@@ -82,6 +79,7 @@ kernel void FN_NAME( \
 ) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
 
+INDEX_OP(is_u32_f32, uint, float)
 
 #if __METAL_VERSION__ >= 310
 IA_OP(bfloat, int64_t, ia_i64_bf16)
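
The INDEX_OP macro flattens the destination as [left, ids, right], so decomposing gid recovers the three coordinates before mapping back into the source. The same arithmetic in plain Rust (illustrative only; src_offset is a hypothetical helper mirroring the macro body above):

// Map a flat destination offset back to a flat source offset, as INDEX_OP does.
fn src_offset(gid: usize, ids: &[u32], src_dim_size: usize, right_size: usize) -> usize {
    let ids_size = ids.len();
    let left_rank_i = gid / right_size / ids_size;
    let id_i = (gid / right_size) % ids_size;
    let right_rank_i = gid % right_size;
    // Clamp like the kernel: unsigned ids need no lower-bound check.
    let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
    (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i
}

For example, with shape [5, 2], ids [0, 4, 2] and dim 0, destination offset 2 decomposes to (left 0, id 1, right 0) and reads source offset 8, i.e. the first element of row 4.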
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 1bcd56d1..6a01107c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -690,6 +690,63 @@ pub fn call_where_cond_strided(
     Ok(())
 }
 
+pub fn call_index_select(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    ids_size: usize,
+    dim: usize,
+    input: &Buffer,
+    ids: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let left_size: usize = shape[..dim].iter().product();
+    let right_size: usize = shape[dim + 1..].iter().product();
+    let src_dim_size = shape[dim];
+    let dst_el = ids_size * left_size * right_size;
+
+    let func = kernels.load_function(device, Source::Indexing, name)?;
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(&func)
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            dst_el,
+            left_size,
+            src_dim_size,
+            right_size,
+            ids_size,
+            input,
+            ids,
+            output
+        )
+    );
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
+    let grid_size = MTLSize {
+        width: (dst_el as u64 + width - 1) / width,
+        height: 1,
+        depth: 1,
+    };
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+    encoder.dispatch_thread_groups(grid_size, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1003,61 +1060,32 @@ mod tests {
         dim: usize,
     ) -> Vec<f32> {
         let device = Device::system_default().expect("no device found");
-        let options = CompileOptions::new();
-        let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
-        let left_size: usize = shape[..dim].iter().product();
-        let right_size: usize = shape[dim + 1..].iter().product();
-        let src_dim_size = shape[dim];
-        let dst_el = ids.len() * left_size * right_size;
-        let ids_size = ids.len();
-
-        let function = library.get_function("is_u32_f32", None).unwrap();
-        let pipeline = device
-            .new_compute_pipeline_state_with_function(&function)
-            .unwrap();
 
         let command_queue = device.new_command_queue();
         let command_buffer = command_queue.new_command_buffer();
-        let encoder = command_buffer.new_compute_command_encoder();
-
-        encoder.set_compute_pipeline_state(&pipeline);
-
         let embeddings_buffer = new_buffer(&device, &embeddings);
         let ids_buffer = new_buffer(&device, &ids);
+
+        let left_size: usize = shape[..dim].iter().product();
+        let right_size: usize = shape[dim + 1..].iter().product();
+        let dst_el = ids.len() * left_size * right_size;
         let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
 
-        set_params!(
-            encoder,
-            (
-                dst_el,
-                left_size,
-                src_dim_size,
-                right_size,
-                ids_size,
-                &embeddings_buffer,
-                &ids_buffer,
-                &mut dst_buffer
-            )
-        );
-
-        let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64);
-        let grid_size = MTLSize {
-            width: (dst_el as u64 + width - 1) / width,
-            height: 1,
-            depth: 1,
-        };
-
-        let thread_group_size = MTLSize {
-            width,
-            height: 1,
-            depth: 1,
-        };
-
-        println!("{width:?} - {:?}", grid_size);
-
-        encoder.dispatch_thread_groups(grid_size, thread_group_size);
-        encoder.end_encoding();
+        let kernels = Kernels::new();
+        call_index_select(
+            &device,
+            &command_buffer,
+            &kernels,
+            "is_u32_f32",
+            shape,
+            ids.len(),
+            dim,
+            &embeddings_buffer,
+            &ids_buffer,
+            &mut dst_buffer,
+        )
+        .unwrap();
 
         command_buffer.commit();
         command_buffer.wait_until_completed();
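
call_index_select launches one thread per destination element: the threadgroup width is the pipeline limit clamped to dst_el, and the grid is the ceiling division of dst_el by that width; the kernel's gid >= dst_size guard absorbs the overshoot in the last group. The sizing math in isolation (illustrative Rust; launch_dims is hypothetical and max_total stands in for pipeline.max_total_threads_per_threadgroup()):

// One thread per destination element.
fn launch_dims(dst_el: u64, max_total: u64) -> (u64, u64) {
    let width = dst_el.min(max_total);         // threads per threadgroup
    let groups = (dst_el + width - 1) / width; // ceil(dst_el / width) threadgroups
    (width, groups)
}

With dst_el = 6 and a limit of 1024 this yields one group of 6 threads; with dst_el = 3000 it yields 3 groups of 1024, the last of which is partially idle.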
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 4dfc46c2..c6984474 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -18,45 +18,55 @@ METAL_FUNC uint get_strided_index(
 
 constant int THREADGROUP_SIZE = 256;
 
-kernel void fast_sum_float(
-    constant size_t &src_numel,
-    constant size_t &el_to_sum_per_block,
-    device const float *src,
-    device float *dst,
-    uint id [[ thread_position_in_grid ]],
-    uint tid [[ thread_index_in_threadgroup ]],
-    uint dst_id [[ threadgroup_position_in_grid ]],
-    uint blockDim [[ threads_per_threadgroup ]]
-) {
-
-    threadgroup float shared_memory[THREADGROUP_SIZE];
-
-    shared_memory[tid] = 0;
-    // Elements summed in this block range from dst_id * el_to_sum_per_block
-    // to (dst_id + 1) * el_to_sum_per_block.
-    size_t start_idx = dst_id * el_to_sum_per_block;
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
-    size_t idx = start_idx + tid;
-
-    while (idx < stop_idx) {
-        // TODO: Fast version for the contiguous case.
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
-        shared_memory[tid] += src[idx];
-        idx += blockDim;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // reduction in shared memory
-    for (uint s = blockDim / 2; s > 0; s >>= 1) {
-        if (tid < s) {
-            shared_memory[tid] += shared_memory[tid + s];
-        }
-        threadgroup_barrier(mem_flags::mem_none);
-    }
-
-    dst[dst_id] = shared_memory[0];
-}
+# define REDUCE(FN, NAME, TYPENAME, START) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint blockDim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup TYPENAME shared_memory[THREADGROUP_SIZE]; \
+    \
+    /* Seed with the operator's identity so idle threads stay neutral. */ \
+    shared_memory[tid] = START; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        */ \
+        TYPENAME x = shared_memory[tid]; \
+        TYPENAME y = src[idx]; \
+        shared_memory[tid] = FN; \
+        idx += blockDim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            TYPENAME x = shared_memory[tid]; \
+            TYPENAME y = shared_memory[tid + s]; \
+            shared_memory[tid] = FN; \
+        } \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
+    } \
+    \
+    dst[dst_id] = shared_memory[0]; \
+}
 
 kernel void softmax_float(
     constant size_t &src_numel,
@@ -122,3 +132,8 @@ kernel void softmax_float(
         idx += blockDim;
     }
 }
+
+
+REDUCE(x + y, fast_sum_float, float, 0)
+REDUCE(x * y, fast_mul_float, float, 1)
+REDUCE(max(x, y), fast_max_float, float, -INFINITY)
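
The REDUCE macro has each thread fold a strided slice of its block into one shared-memory slot, then halves the number of active threads each round until slot 0 holds the block result; seeding with the operator's identity (0 for +, 1 for *, -INFINITY for max) keeps threads whose range is empty from skewing the result. The same tree reduction over a plain buffer (illustrative Rust; tree_reduce is a hypothetical helper, assuming a power-of-two length like the 256-wide threadgroup):

// Tree reduction mirroring the shared-memory loop: each round, the first
// half of the buffer folds in the second half, halving the active span.
fn tree_reduce(buf: &mut [f32], f: impl Fn(f32, f32) -> f32) -> f32 {
    assert!(buf.len().is_power_of_two());
    let mut s = buf.len() / 2;
    while s > 0 {
        for tid in 0..s {
            buf[tid] = f(buf[tid], buf[tid + s]);
        }
        s /= 2;
    }
    buf[0]
}

On the GPU the per-round barrier replaces this sequential loop's implicit ordering, which is why the kernel needs a threadgroup memory fence between rounds.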