From 16f0f5b9d2b9a2cda42d1754783a68326aa8b1ba Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 26 Jun 2023 11:47:57 +0100
Subject: [PATCH] Add a cuda kernel for embeddings.

---
 kernels/src/embeddings.cu | 34 ++++++++++++++++++++
 kernels/src/lib.rs        |  1 +
 src/cuda_backend.rs       | 66 +++++++++++++++++++++++++++++++++------
 src/dtype.rs              |  3 ++
 src/error.rs              |  8 +++--
 5 files changed, 101 insertions(+), 11 deletions(-)
 create mode 100644 kernels/src/embeddings.cu

diff --git a/kernels/src/embeddings.cu b/kernels/src/embeddings.cu
new file mode 100644
index 00000000..79bd85a4
--- /dev/null
+++ b/kernels/src/embeddings.cu
@@ -0,0 +1,34 @@
+// WARNING: THIS IS ONLY VALID ASSUMING THAT inp IS CONTIGUOUS!
+// TODO: proper error reporting when ids are larger than v_size.
+#include "cuda_utils.cuh"
+#include <stdint.h>
+
+#define EMB_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const uint32_t *ids, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t h_size, \
+    const size_t v_size \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            memcpy(out + i * h_size, inp + ids[i] * h_size, h_size * sizeof(TYPENAME)); \
+        } \
+    } \
+    else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+            memcpy(out + i * h_size, inp + ids[strided_i] * h_size, h_size * sizeof(TYPENAME)); \
+        } \
+    } \
+} \
+
+EMB_OP(float, emb_f32)
+EMB_OP(double, emb_f64)
+EMB_OP(uint32_t, emb_u32)
diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs
index 8e0d9eb9..df9c2f85 100644
--- a/kernels/src/lib.rs
+++ b/kernels/src/lib.rs
@@ -1,4 +1,5 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
+pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 51eb3ac2..710227d0 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -24,6 +24,13 @@ pub enum CudaError {
 
     #[error("internal error '{0}'")]
     InternalError(&'static str),
+
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
 }
 
 type Result<T> = std::result::Result<T, CudaError>;
@@ -349,9 +356,6 @@ impl CudaStorage {
         rhs_stride: &[usize],
     ) -> Result<Self> {
         let dims = shape.dims();
-        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
-            return Err(CudaError::InternalError("TODO: implement broadcast"));
-        }
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
@@ -425,13 +429,57 @@ impl CudaStorage {
 
     pub(crate) fn embedding_impl(
         &self,
-        _shape: &Shape,
-        _stride: &[usize],
-        _rhs: &Self,
-        _hidden_size: usize,
-        _vocab_size: usize,
+        shape: &Shape,
+        stride: &[usize],
+        rhs: &Self,
+        h_size: usize, // hidden size
+        v_size: usize, // vocab size
     ) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement embedding"))
+        let ids = match &self.slice {
+            CudaStorageSlice::U32(slice) => slice,
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "embedding ids should be u32",
+                expected: DType::U32,
+                got: self.dtype(),
+            })?,
+        };
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let dev = self.device();
+        let ds = dev.htod_copy([dims, stride].concat())?;
+        let slice = match &rhs.slice {
+            // The kernels below assume that rhs is contiguous.
+            CudaStorageSlice::U32(arg) => {
+                let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::U32(out)
+            }
+            CudaStorageSlice::F32(arg) => {
+                let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F32(out)
+            }
+            CudaStorageSlice::F64(arg) => {
+                let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F64(out)
+            }
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }
 
     pub(crate) fn matmul_impl(
diff --git a/src/dtype.rs b/src/dtype.rs
index a11ab350..471f415c 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -46,6 +46,7 @@ macro_rules! with_dtype {
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
@@ -56,6 +57,7 @@
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
@@ -66,6 +68,7 @@
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
diff --git a/src/error.rs b/src/error.rs
index 31bcf4c3..83d3e66d 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -3,8 +3,12 @@ use crate::{DType, DeviceLocation, Shape};
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
-    #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
-    UnexpectedDType { expected: DType, got: DType },
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
 
     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },