From 74a6a769ddadede96ec4495cc5edd316e0827150 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 24 Jul 2023 21:53:08 +0100
Subject: [PATCH] Cuda kernels for IndexAdd/ScatterAdd. (#236)

* Skeleton methods for IndexAdd/ScatterAdd.

* Add a Map2InPlace trait.

* Add the glue code for the index-add/scatter-add kernels.

* Tweak the file name: embeddings -> indexing.

* Add the cuda kernel for indexadd.

* And add the scatter-add kernels.
---
 candle-core/src/cuda_backend.rs            | 182 +++++++++++++++---
 .../src/{embeddings.cu => indexing.cu}     | 101 ++++++++++
 candle-kernels/src/lib.rs                  |   2 +-
 3 files changed, 255 insertions(+), 30 deletions(-)
 rename candle-kernels/src/{embeddings.cu => indexing.cu} (60%)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 18d028ad..c550d982 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -398,12 +398,42 @@ trait Map2 {
             (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
             (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
             (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
+            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
         };
         Ok(out)
     }
 }
 
+trait Map2InPlace {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_shape: &Shape,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()>;
+
+    fn map(
+        &self,
+        dst: &mut S,
+        dst_s: &Shape,
+        src: &S,
+        src_l: &Layout,
+        d: &CudaDevice,
+    ) -> Result<()> {
+        match (dst, src) {
+            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
+            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
+        }
+    }
+}
+
 trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -651,7 +681,7 @@ impl<'a> Map1 for Embedding<'a> {
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
         let rhs = &rhs.slice(rhs_l.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
         let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
@@ -696,7 +726,7 @@ impl<'a> Map1 for IndexSelect<'a> {
         let left_size: usize = src_l.dims()[..self.2].iter().product();
         let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
         let dim_size = src_l.dims()[self.2];
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
         let params = (
@@ -752,7 +782,7 @@ impl<'a> Map1 for Gather<'a> {
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
         let ids_dim_sz = ids_l.dims()[dim];
-        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::EMBEDDINGS)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (
@@ -764,6 +794,97 @@ impl<'a> Map1 for Gather<'a> {
     }
 }
 
+struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
+impl<'a> Map2InPlace for IndexAdd<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_shape: &Shape,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let (name, ids) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "index-add ids should be u8 or u32",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let dst_dim_sz = dst_shape.dims()[dim];
+        let ids_dim_sz = ids_l.dims()[0];
+        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
+        // SAFETY: Set later by running the kernel.
+        let params = (
+            ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(())
+    }
+}
+
+struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
+impl<'a> Map2InPlace for ScatterAdd<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_shape: &Shape,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
+        let (name, ids) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "scatter-add ids should be u8 or u32",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let dst_dim_sz = dst_shape.dims()[dim];
+        let ids_dim_sz = ids_l.dims()[dim];
+        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
+        // SAFETY: Set later by running the kernel.
+        let params = (
+            ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(())
+    }
+}
+
 struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
 impl<'a> Map2 for Conv1D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1004,8 +1125,7 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
             mnk: (m, n, k),
-        })
-        .w()?
+        })?
     };
     // The b tensor has dims batching, m, k (lhs)
     let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -1017,8 +1137,7 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
             mnk: (m, n, k),
-        })
-        .w()?
+        })?
     };
     // The setup below was copied from:
     // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
@@ -1043,8 +1162,7 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
             mnk: (m, n, k),
-        })
-        .w()?,
+        })?,
     };
     let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
         [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
@@ -1054,8 +1172,7 @@ fn gemm_config<T>(
             lhs_stride: lhs_stride.to_vec(),
             rhs_stride: rhs_stride.to_vec(),
             mnk: (m, n, k),
-        })
-        .w()?,
+        })?,
     };
 
     Ok(StridedBatchedConfig {
@@ -1281,25 +1398,33 @@ impl BackendStorage for CudaStorage {
     }
     fn scatter_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
     ) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement scatter-add").into())
+        let device = self.device().clone();
+        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut acc, 0, l)?;
+        ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+        Ok(acc)
     }
     fn index_add(
         &self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: &Self,
-        _: &Layout,
-        _: usize,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
    ) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement index-add").into())
+        let device = self.device().clone();
+        let mut acc = device.zeros_impl(l.shape(), self.dtype())?;
+        self.copy_strided_src(&mut acc, 0, l)?;
+        IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+        Ok(acc)
     }
     fn matmul(
@@ -1364,7 +1489,7 @@ impl BackendStorage for CudaStorage {
                 .w()?;
                 CudaStorageSlice::F64(out)
             }
-            _ => Err(CudaError::InternalError("dtype mismatch in matmul op")).w()?,
+            _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
         };
         let device = dev.clone();
         Ok(Self { slice, device })
@@ -1452,8 +1577,7 @@ impl BackendStorage for CudaStorage {
             }
             _ => Err(CudaError::InternalError(
                 "dtype mismatch in copy_strided op",
-            ))
-            .w()?,
+            ))?,
         }
         Ok(())
     }
diff --git a/candle-kernels/src/embeddings.cu b/candle-kernels/src/indexing.cu
similarity index 60%
rename from candle-kernels/src/embeddings.cu
rename to candle-kernels/src/indexing.cu
index 335e7282..fb2d56b2 100644
--- a/candle-kernels/src/embeddings.cu
+++ b/candle-kernels/src/indexing.cu
@@ -105,6 +105,79 @@ extern "C" __global__ void FN_NAME( \
     const size_t right_size \
 ) { gather(numel, ids, inp, out, left_size, src_dim_size, ids_dim_size, right_size); } \
 
+template<typename T, typename I>
+__device__ void index_add(
+    const I *ids,
+    const size_t ids_dim_size,
+    const T *inp,
+    T *out,
+    const size_t left_size,
+    const size_t src_dim_size,
+    const size_t dst_dim_size,
+    const size_t right_size
+) {
+    const size_t numel = left_size * right_size;
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        const size_t pre = i / right_size;
+        const size_t post = i % right_size;
+        for (unsigned int j = 0; j < ids_dim_size; ++j) {
+            const size_t idx = ids[j];
+            const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
+            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
+            out[dst_i] += inp[src_i];
+        }
+    }
+}
+
+#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const INDEX_TYPENAME *ids, \
+    const size_t ids_dim_size, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t left_size, \
+    const size_t src_dim_size, \
+    const size_t dst_dim_size, \
+    const size_t right_size \
+) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
+
+template<typename T, typename I>
+__device__ void scatter_add(
+    const I *ids,
+    const size_t ids_dim_size,
+    const T *inp,
+    T *out,
+    const size_t left_size,
+    const size_t src_dim_size,
+    const size_t dst_dim_size,
+    const size_t right_size
+) {
+    const size_t numel = left_size * right_size;
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        const size_t pre = i / right_size;
+        const size_t post = i % right_size;
+        for (unsigned int j = 0; j < ids_dim_size; ++j) {
+            const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
+            const size_t idx = ids[src_i];
+            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
+            out[dst_i] += inp[src_i];
+        }
+    }
+}
+
+#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const INDEX_TYPENAME *ids, \
+    const size_t ids_dim_size, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t left_size, \
+    const size_t src_dim_size, \
+    const size_t dst_dim_size, \
+    const size_t right_size \
+) { scatter_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
+
+
 #if __CUDA_ARCH__ >= 800
 EMB_OP(__nv_bfloat16, uint32_t, emb_u32_bf16)
 EMB_OP(__nv_bfloat16, uint8_t, emb_u8_bf16)
@@ -112,6 +185,10 @@ IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
 IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
 GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
 GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
+IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
+IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
+SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
+SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
 EMB_OP(__half, uint32_t, emb_u32_f16)
@@ -121,6 +198,10 @@ IS_OP(__half, uint32_t, is_u32_f16)
 IS_OP(__half, uint8_t, is_u8_f16)
 GATHER_OP(__half, uint32_t, gather_u32_f16)
 GATHER_OP(__half, uint8_t, gather_u8_f16)
+IA_OP(__half, uint32_t, ia_u32_f16)
+IA_OP(__half, uint8_t, ia_u8_f16)
+SA_OP(__half, uint32_t, sa_u32_f16)
+SA_OP(__half, uint8_t, sa_u8_f16)
 #endif
 
 EMB_OP(float, uint32_t, emb_u32_f32)
@@ -152,3 +233,23 @@ GATHER_OP(float, uint8_t, gather_u8_f32)
 GATHER_OP(double, uint8_t, gather_u8_f64)
 GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
 GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
+
+IA_OP(float, uint32_t, ia_u32_f32)
+IA_OP(double, uint32_t, ia_u32_f64)
+IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(uint32_t, uint32_t, ia_u32_u32)
+
+IA_OP(float, uint8_t, ia_u8_f32)
+IA_OP(double, uint8_t, ia_u8_f64)
+IA_OP(uint8_t, uint8_t, ia_u8_u8)
+IA_OP(uint32_t, uint8_t, ia_u8_u32)
+
+SA_OP(float, uint32_t, sa_u32_f32)
+SA_OP(double, uint32_t, sa_u32_f64)
+SA_OP(uint8_t, uint32_t, sa_u32_u8)
+SA_OP(uint32_t, uint32_t, sa_u32_u32)
+
+SA_OP(float, uint8_t, sa_u8_f32)
+SA_OP(double, uint8_t, sa_u8_f64)
+SA_OP(uint8_t, uint8_t, sa_u8_u8)
+SA_OP(uint32_t, uint8_t, sa_u8_u32)
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index b9d12b7b..478debd3 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -2,8 +2,8 @@ pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
 pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
 pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
-pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
+pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
 pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
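
A minimal usage sketch of the two ops wired up above, assuming candle's public Tensor::index_add and Tensor::scatter_add wrappers (the entry points the CPU backend already serves) and a CUDA-enabled build; the shapes and values are illustrative only, not part of the patch:

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let dev = Device::new_cuda(0)?;

        // index-add along dim 0: out[ids[j], :] += src[j, :] for each j,
        // so repeated ids accumulate into the same destination row.
        let dst = Tensor::zeros((4, 2), DType::F32, &dev)?;
        let ids = Tensor::new(&[0u32, 3, 0], &dev)?;
        let src = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
        let out = dst.index_add(&ids, &src, 0)?;
        // Row 0 is src row 0 + src row 2 -> [6., 8.]; row 3 is [3., 4.].
        println!("{out}");

        // scatter-add along dim 0: ids has the same shape as src and
        // out[ids[i][j], j] += src[i][j] for every position (i, j).
        let dst = Tensor::zeros((3, 2), DType::F32, &dev)?;
        let ids = Tensor::new(&[[0u32, 1], [2, 1]], &dev)?;
        let src = Tensor::new(&[[1f32, 2.], [3., 4.]], &dev)?;
        let out = dst.scatter_add(&ids, &src, 0)?;
        // Both source rows hit row 1 in column 1 -> out row 1 is [0., 6.].
        println!("{out}");
        Ok(())
    }

As in the kernels above, the ids tensor must be u8 or u32; anything else hits the UnexpectedDType error path.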
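
The flattened-index arithmetic the kernels rely on can be restated as plain host-side Rust. index_add_ref below is a hypothetical reference helper, not code from the patch: the (pre, post) pairs correspond to the left_size * right_size grid positions the GPU distributes over threads, and the inner j-loop mirrors each thread's walk over the ids:

    // Reference semantics of the index_add kernel, on row-major buffers.
    fn index_add_ref(
        ids: &[u32],
        src: &[f32],
        dst: &mut [f32],
        left: usize,
        dst_dim: usize,
        right: usize,
    ) {
        let ids_dim = ids.len();
        for pre in 0..left {
            for post in 0..right {
                for (j, &idx) in ids.iter().enumerate() {
                    // Same flattening as the CUDA code: (outer * dim + inner) * right + post.
                    let src_i = (pre * ids_dim + j) * right + post;
                    let dst_i = (pre * dst_dim + idx as usize) * right + post;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }

    fn main() {
        let ids = [0u32, 3, 0];
        let src = [1f32, 2., 3., 4., 5., 6.]; // 3 x 2, row-major
        let mut dst = [0f32; 8]; // 4 x 2
        index_add_ref(&ids, &src, &mut dst, 1, 4, 2);
        assert_eq!(dst, [6., 8., 0., 0., 0., 0., 3., 4.]);
        println!("{dst:?}");
    }

Note the design choice this makes visible: each (pre, post) position is owned by exactly one GPU thread and the whole walk over the index dimension happens inside that thread, so no two threads ever update the same output element. The kernels therefore need no atomic adds and accumulate in a deterministic order, at the cost of serializing over the index dimension within each thread.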