From 5e1c595e00721a11bb46e9187ea7d86ea4ace0e3 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 28 Sep 2023 09:05:29 +0100 Subject: [PATCH] Optimize the index-select cuda kernel. (#976) --- candle-core/src/cuda_backend.rs | 8 ++++---- candle-kernels/src/indexing.cu | 22 ++++++++-------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 00fd1d04..d1187b1c 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -884,8 +884,6 @@ impl<'a> Map1 for IndexSelect<'a> { }; let ids_shape = ids_l.shape(); let ids_dims = ids_shape.dims(); - let ids_el = ids_shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(ids_el as u32); let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), @@ -894,11 +892,13 @@ impl<'a> Map1 for IndexSelect<'a> { let left_size: usize = src_l.dims()[..self.2].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); let dim_size = src_l.dims()[self.2]; + let dst_el = ids_shape.elem_count() * left_size * right_size; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(ids_el * left_size * right_size) }.w()?; + let out = unsafe { dev.alloc::(dst_el) }.w()?; let params = ( - ids_el, + dst_el, ids_dims.len(), &ds, ids, diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index c57be129..0272a330 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -17,20 +17,14 @@ __device__ void index_select( ) { const size_t *dims = info; const size_t *strides = info + num_dims; - if (is_contiguous(num_dims, dims, strides)) { - for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { - for (unsigned int j = 0; j < left_size; ++j) { - memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[i]) * right_size], right_size * sizeof(T)); - } - } - } - else { - for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { - unsigned strided_i = get_strided_index(i, num_dims, dims, strides); - for (unsigned int j = 0; j < left_size; ++j) { - memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[strided_i]) * right_size], right_size * sizeof(T)); - } - } + bool b = is_contiguous(num_dims, dims, strides); + for (unsigned int dst_i = blockIdx.x * blockDim.x + threadIdx.x; dst_i < numel; dst_i += blockDim.x * gridDim.x) { + unsigned int left_i = dst_i / (dim_size * right_size); + unsigned int id_i = dst_i / right_size % dim_size; + unsigned int right_i = dst_i % right_size; + unsigned int src_i = left_i * (dim_size * right_size) + ids[id_i] * right_size + right_i; + unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides); + out[dst_i] = inp[strided_i]; } }