From 581b104f972ca868a1c8a859e86efabbc37cec68 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 24 Jul 2023 20:22:47 +0100
Subject: [PATCH] Indexing cuda (#235)

* Allow using uint8_t for indexing.

* Revert the default cuda feature.

* Add a cuda-kernel for index-select.

* Add a test for gather.
---
 candle-core/src/cuda_backend.rs   | 142 ++++++++++++++++++++++++++----
 candle-core/tests/tensor_tests.rs |  39 +++++++-
 candle-kernels/src/embeddings.cu  | 127 ++++++++++++++++++++++++--
 3 files changed, 277 insertions(+), 31 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index d4d8faaa..18d028ad 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -5,7 +5,8 @@ use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
-    CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+    CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
+    ValidAsZeroBits,
 };
 use half::{bf16, f16};
 use std::sync::{Arc, Mutex};
@@ -34,9 +35,6 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),
 
-    #[error("internal error '{0}'")]
-    WrappedError(Box<dyn std::error::Error + Send + Sync>),
-
     #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
     MatMulNonContiguous {
         lhs_stride: Vec<usize>,
@@ -632,28 +630,28 @@ impl<'a> Map1 for Embedding<'a> {
         rhs_l: &Layout,
     ) -> Result<CudaSlice<T>> {
         let ids_l = &self.1;
-        let ids = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+        let (name, ids) = match &self.0.slice {
+            CudaStorageSlice::U32(slice) => {
+                ("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
+            CudaStorageSlice::U8(slice) => {
+                ("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "embedding ids should be u32",
+                msg: "embedding ids should be u8 or u32",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
             .w()?,
         };
-        let ids = &ids;
         let shape = ids_l.shape();
-        let (v_size, h_size) = rhs_l
-            .shape()
-            .dims2()
-            .map_err(|e| CudaError::WrappedError(Box::new(e)))
-            .w()?;
+        let (v_size, h_size) = rhs_l.shape().dims2()?;
         let dims = shape.dims();
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
         let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
         let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
@@ -663,6 +661,109 @@ impl<'a> Map1 for Embedding<'a> {
     }
 }
 
+struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
+impl<'a> Map1 for IndexSelect<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        src_l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let ids_l = &self.1;
+        let (name, ids) = match &self.0.slice {
+            CudaStorageSlice::U32(slice) => {
+                ("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
+            CudaStorageSlice::U8(slice) => {
+                ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
+            }
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "index_select ids should be u8 or u32",
+                expected: DType::U32,
+                got: self.0.dtype(),
+            })
+            .w()?,
+        };
+        let ids_shape = ids_l.shape();
+        let ids_dims = ids_shape.dims();
+        let ids_el = ids_shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(ids_el as u32);
+        let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
+        };
+        let left_size: usize = src_l.dims()[..self.2].iter().product();
+        let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
+        let dim_size = src_l.dims()[self.2];
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
+        let params = (
+            ids_el,
+            ids_dims.len(),
+            &ds,
+            ids,
+            &src,
+            &out,
+            left_size,
+            dim_size,
+            right_size,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(out)
+    }
+}
+
+struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
+impl<'a> Map1 for Gather<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        src_l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
+        };
+        let (name, ids) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => {
+                ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
+            }
+            CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "gather ids should be u8 or u32",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let el = ids_l.shape().elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let ids_dim_sz = ids_l.dims()[dim];
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let params = (
+            el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(out)
+    }
+}
+
 struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
 impl<'a> Map2 for Conv1D<'a> {
     fn f(
@@ -991,7 +1092,6 @@ impl BackendStorage for CudaStorage {
     }
 
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
-        use cudarc::driver::DevicePtr;
         let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
@@ -1169,11 +1269,15 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement index-select").into())
+    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let device = self.device().clone();
+        let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
+        Ok(Self { slice, device })
     }
-    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement gather").into())
+    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let device = self.device().clone();
+        let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
+        Ok(Self { slice, device })
     }
     fn scatter_add(
         &self,
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 244bec58..501c55ec 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -316,10 +316,7 @@ fn cmp(device: &Device) -> Result<()> {
     Ok(())
 }
 
-#[test]
-fn index_select() -> Result<()> {
-    // TODO: Test on cuda once the kernel is available.
-    let device = &Device::Cpu;
+fn index_select(device: &Device) -> Result<()> {
     let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
@@ -349,6 +346,38 @@
     Ok(())
 }
 
+fn gather(device: &Device) -> Result<()> {
+    let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], device)?;
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let hs = t.gather(&ids, 1)?;
+    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0], [5.0], [7.0], [9.0]]);
+    let ids = Tensor::new(
+        &[[0u32, 0u32], [2u32, 0u32], [1u32, 1u32], [0u32, 2u32]],
+        device,
+    )?;
+    let hs = t.gather(&ids, 1)?;
+    assert_eq!(
+        hs.to_vec2::<f32>()?,
+        &[[0.0, 0.0], [5.0, 3.0], [7.0, 7.0], [9.0, 11.0]]
+    );
+    let ids = Tensor::new(&[[0u32, 2u32, 0u32]], device)?;
+    let hs = t.gather(&ids, 0)?;
+    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0]]);
+    let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
+    let hs = t.gather(&ids, 0)?;
+    assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
+    Ok(())
+}
+
 fn matmul(device: &Device) -> Result<()> {
     let data = vec![1.0f32, 2.0, 3.0, 4.0];
     let a = Tensor::from_slice(&data, (2, 2), device)?;
@@ -513,3 +542,5 @@ test_device!(embeddings, embeddings_cpu, embeddings_gpu);
 test_device!(cmp, cmp_cpu, cmp_gpu);
 test_device!(matmul, matmul_cpu, matmul_gpu);
 test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
+test_device!(index_select, index_select_cpu, index_select_gpu);
+test_device!(gather, gather_cpu, gather_gpu);
diff --git a/candle-kernels/src/embeddings.cu b/candle-kernels/src/embeddings.cu
index 8425b16b..335e7282 100644
--- a/candle-kernels/src/embeddings.cu
+++ b/candle-kernels/src/embeddings.cu
@@ -3,12 +3,12 @@
 #include "cuda_utils.cuh"
 #include <stdint.h>
 
-#define EMB_OP(TYPENAME, FN_NAME) \
+#define EMB_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
     const size_t num_dims, \
     const size_t *info, \
-    const uint32_t *ids, \
+    const INDEX_TYPENAME *ids, \
     const TYPENAME *inp, \
     TYPENAME *out, \
     const size_t h_size, \
@@ -29,15 +29,126 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+template <typename T, typename I>
+__device__ void index_select(
+    const size_t numel,
+    const size_t num_dims,
+    const size_t *info,
+    const I *ids,
+    const T *inp,
+    T *out,
+    const size_t left_size,
+    const size_t dim_size,
+    const size_t right_size
+) {
+    const size_t *dims = info;
+    const size_t *strides = info + num_dims;
+    if (is_contiguous(num_dims, dims, strides)) {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            for (unsigned int j = 0; j < left_size; ++j) {
+                memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[i]) * right_size], right_size * sizeof(T));
+            }
+        }
+    }
+    else {
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
+            for (unsigned int j = 0; j < left_size; ++j) {
+                memcpy(&out[(i + j * numel) * right_size], &inp[(j * dim_size + ids[strided_i]) * right_size], right_size * sizeof(T));
+            }
+        }
+    }
+}
+
+#define IS_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const INDEX_TYPENAME *ids, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t left_size, \
+    const size_t dim_size, \
+    const size_t right_size \
+) { index_select(numel, num_dims, info, ids, inp, out, left_size, dim_size, right_size); } \
+
+template <typename T, typename I>
+__device__ void gather(
+    const size_t numel,
+    const I *ids,
+    const T *inp,
+    T *out,
+    const size_t left_size,
+    const size_t src_dim_size,
+    const size_t ids_dim_size,
+    const size_t right_size
+) {
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        size_t post = i % right_size;
+        size_t idx = ids[i];
+        size_t pre = i / (right_size * ids_dim_size);
+        size_t src_i = (pre * src_dim_size + idx) * right_size + post;
+        out[i] = inp[src_i];
+    }
+}
+
+#define GATHER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const INDEX_TYPENAME *ids, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t left_size, \
+    const size_t src_dim_size, \
+    const size_t ids_dim_size, \
+    const size_t right_size \
+) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
+
 #if __CUDA_ARCH__ >= 800
-EMB_OP(__nv_bfloat16, emb_bf16)
+EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
+EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
+IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
+IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
+GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
+GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
-EMB_OP(__half, emb_f16)
+EMB_OP(__half, uint32_t, emb_u32_f16)
+EMB_OP(__half, uint8_t, emb_u8_f16)
+IS_OP(__half, uint32_t, is_u32_f16)
+IS_OP(__half, uint8_t, is_u8_f16)
+GATHER_OP(__half, uint32_t, gather_u32_f16)
+GATHER_OP(__half, uint8_t, gather_u8_f16)
 #endif
 
-EMB_OP(float, emb_f32)
-EMB_OP(double, emb_f64)
-EMB_OP(uint8_t, emb_u8)
-EMB_OP(uint32_t, emb_u32)
+EMB_OP(float, uint32_t, emb_u32_f32)
+EMB_OP(double, uint32_t, emb_u32_f64)
+EMB_OP(uint8_t, uint32_t, emb_u32_u8)
+EMB_OP(uint32_t, uint32_t, emb_u32_u32)
+
+EMB_OP(float, uint8_t, emb_u8_f32)
+EMB_OP(double, uint8_t, emb_u8_f64)
+EMB_OP(uint8_t, uint8_t, emb_u8_u8)
+EMB_OP(uint32_t, uint8_t, emb_u8_u32)
+
+IS_OP(float, uint32_t, is_u32_f32)
+IS_OP(double, uint32_t, is_u32_f64)
+IS_OP(uint8_t, uint32_t, is_u32_u8)
+IS_OP(uint32_t, uint32_t, is_u32_u32)
+
+IS_OP(float, uint8_t, is_u8_f32)
+IS_OP(double, uint8_t, is_u8_f64)
+IS_OP(uint8_t, uint8_t, is_u8_u8)
+IS_OP(uint32_t, uint8_t, is_u8_u32)
+
+GATHER_OP(float, uint32_t, gather_u32_f32)
+GATHER_OP(double, uint32_t, gather_u32_f64)
+GATHER_OP(uint8_t, uint32_t, gather_u32_u8)
+GATHER_OP(uint32_t, uint32_t, gather_u32_u32)
+
+GATHER_OP(float, uint8_t, gather_u8_f32)
+GATHER_OP(double, uint8_t, gather_u8_f64)
+GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
+GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
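
Not part of the patch above: a minimal usage sketch of the index_select/gather paths this change wires up on CUDA, mirroring the new tests. It assumes the candle-core public API of the time (Device::new_cuda, Tensor::index_select, Tensor::gather), a CUDA-enabled build with at least one visible GPU, and that the frontend accepts u8 ids the way the new kernels do.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Assumption: candle-core is built with CUDA support and GPU 0 is available.
    let device = Device::new_cuda(0)?;
    let t = Tensor::arange(0f32, 12f32, &device)?.reshape((4, 3))?;

    // Row selection now runs on the GPU; u8 ids are handled by the new is_u8_* kernels.
    let ids = Tensor::new(&[0u8, 2u8, 1u8], &device)?;
    let rows = t.index_select(&ids, 0)?;
    println!("{rows}");

    // Gather along dim 1, matching the new gather test.
    let ids = Tensor::new(&[[0u32], [2u32], [1u32], [0u32]], &device)?;
    let picked = t.gather(&ids, 1)?;
    println!("{picked}");
    Ok(())
}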