From fee33b45c2b635d83fa2ca0955ae453fe26374ea Mon Sep 17 00:00:00 2001 From: Thomas Santerre Date: Fri, 22 Mar 2024 02:30:02 -0400 Subject: [PATCH] Add support for strided index-select on Metal (#1909) * initial implementation * use correct index, but still not breaking like it should have... * fix test --- candle-core/src/metal_backend.rs | 18 +++--- candle-metal-kernels/src/indexing.metal | 41 ++++++++++--- candle-metal-kernels/src/lib.rs | 12 +++- candle-metal-kernels/src/tests.rs | 81 +++++++++++++++++++++++-- 4 files changed, 129 insertions(+), 23 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index ef044fc8..73a141ea 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -2,9 +2,8 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; +use candle_metal_kernels::CallConvTranspose2dCfg; use candle_metal_kernels::Kernels; -use candle_metal_kernels::{self, CallConvTranspose2dCfg}; -use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; @@ -1348,12 +1347,8 @@ impl BackendStorage for MetalStorage { } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - if !(src_l.is_contiguous() - && src_l.start_offset() == 0 - && ids_l.is_contiguous() - && ids_l.start_offset() == 0) - { - crate::bail!("Metal strided index_select not implemented"); + if !ids_l.is_contiguous() { + crate::bail!("Metal index_select requires contiguous ids") } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); @@ -1364,6 +1359,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U8, DType::F32) => "is_u8_f32", + (DType::U8, DType::F16) => "is_u8_f16", (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", @@ -1382,8 +1379,13 @@ impl BackendStorage for MetalStorage { src_l.dims(), ids_el, dim, + src_l.is_contiguous(), + src_l.dims(), + src_l.stride(), &self.buffer, + src_l.start_offset() * dtype.size_in_bytes(), &ids.buffer, + ids_l.start_offset() * ids.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 65491759..ad4a8605 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,20 +1,38 @@ #include using namespace metal; +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + template METAL_FUNC void index( constant size_t &dst_size, constant size_t &left_size, constant size_t &src_dim_size, constant size_t &right_size, - constant size_t &ids_size, - const device TYPENAME *input, + constant size_t &ids_size, + constant bool &contiguous, + constant size_t *src_dims, + constant size_t *src_strides, + const device TYPENAME *input, const device INDEX_TYPENAME *input_ids, device TYPENAME *output, uint tid [[ thread_position_in_grid ]] ) { if (tid >= dst_size) { - return; + return; } const size_t id_i = (tid / right_size) % ids_size; const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); @@ -26,7 +44,8 @@ METAL_FUNC void index( // No need to check for zero we're only allowing unsized. */ const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; - output[tid] = input[src_i]; + const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); + output[tid] = input[strided_src_i]; } # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ @@ -36,12 +55,15 @@ kernel void NAME( \ constant size_t &src_dim_size, \ constant size_t &right_size, \ constant size_t &ids_size, \ + constant bool &contiguous, \ + constant size_t *src_dims, \ + constant size_t *src_strides, \ const device TYPENAME *input, \ const device INDEX_TYPENAME *input_ids, \ device TYPENAME *output, \ uint tid [[ thread_position_in_grid ]] \ ) { \ - index(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \ + index(dst_size, left_size, src_dim_size, right_size, ids_size, contiguous, src_dims, src_strides, input, input_ids, output, tid); \ } @@ -165,10 +187,15 @@ kernel void NAME( \ } -INDEX_OP(is_u32_f32, uint, float) -INDEX_OP(is_u32_f16, uint, half) +INDEX_OP(is_u32_f32, uint32_t, float) +INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) INDEX_OP(is_u32_bf16, uint32_t, bfloat) +#endif + +INDEX_OP(is_u8_f32, uint8_t, float) +INDEX_OP(is_u8_f16, uint8_t, half) +#if defined(__HAVE_BFLOAT__) INDEX_OP(is_u8_bf16, uint8_t, bfloat) #endif diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f2c9c7fe..e17365a0 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1067,8 +1067,13 @@ pub fn call_index_select( shape: &[usize], ids_size: usize, dim: usize, + contiguous: bool, + src_dims: &[usize], + src_strides: &[usize], input: &Buffer, + src_offset: usize, ids: &Buffer, + ids_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); @@ -1090,8 +1095,11 @@ pub fn call_index_select( src_dim_size, right_size, ids_size, - input, - ids, + contiguous, + src_dims, + src_strides, + (input, src_offset), + (ids, ids_offset), output ) ); diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 5045a4a3..b15d9b36 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -600,22 +600,35 @@ fn affine_strided() { fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; + let stride = [1, 2]; let ids = [0u32, 1, 0]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); } +#[test] +fn index_select_strided() { + let embedding = (0..16).map(|x| x as f32).collect::>(); + let shape = [2, 2]; + let stride = [2, 4]; + let ids = [0u32]; + let dim = 0; + let result = run_index_select_strided(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); + assert_eq!(result, vec![0.0, 4.0]); +} + #[test] fn index_select_f16() { let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] @@ -623,9 +636,10 @@ fn index_select_f16() { .map(|x| f16::from_f32(x)) .collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -636,9 +650,10 @@ fn index_select_f16() { fn index_select_is_u32_bf16() { let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_bf16"); assert_eq!( approx_bf16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -649,9 +664,10 @@ fn index_select_is_u32_bf16() { fn index_select_is_u8_bf16() { let embedding: Vec = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u8, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u8_bf16"); assert_eq!( approx_bf16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -662,9 +678,10 @@ fn index_select_is_u8_bf16() { fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; + let stride = [2, 1]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); + let result = run_index_select(&embedding, &shape, &stride, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -674,6 +691,7 @@ fn index_select_dim1() { fn run_index_select( embeddings: &[T], shape: &[usize], + stride: &[usize], ids: &[I], dim: usize, name: &'static str, @@ -699,8 +717,59 @@ fn run_index_select( shape, ids.len(), dim, + true, + shape, + stride, &embeddings_buffer, + 0, &ids_buffer, + 0, + &dst_buffer, + ) + .unwrap(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&dst_buffer, dst_el) +} + +fn run_index_select_strided( + embeddings: &[T], + shape: &[usize], + stride: &[usize], + ids: &[I], + dim: usize, + name: &'static str, +) -> Vec { + let device = Device::system_default().expect("no device found"); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let embeddings_buffer = new_buffer(&device, &embeddings); + let ids_buffer = new_buffer(&device, &ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let kernels = Kernels::new(); + call_index_select( + &device, + &command_buffer, + &kernels, + name, + shape, + ids.len(), + dim, + false, + shape, + stride, + &embeddings_buffer, + 0, + &ids_buffer, + 0, &dst_buffer, ) .unwrap();