Add a cuda kernel for embeddings.

laurent
2023-06-26 11:47:57 +01:00
parent 5952c3fa91
commit 16f0f5b9d2
5 changed files with 101 additions and 11 deletions

kernels/src/embeddings.cu

@@ -0,0 +1,34 @@
// WARNING: THIS IS ONLY VALID ASSUMING THAT inp IS CONTIGUOUS!
// TODO: proper error reporting when ids are larger than v_size.
#include "cuda_utils.cuh"
#include <stdint.h>
#define EMB_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *info, \
const uint32_t *ids, \
const TYPENAME *inp, \
TYPENAME *out, \
const size_t h_size, \
const size_t v_size \
) { \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
memcpy(out + i * h_size, inp + ids[i] * h_size, h_size * sizeof(TYPENAME)); \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
memcpy(out + i * h_size, inp + ids[strided_i] * h_size, h_size * sizeof(TYPENAME)); \
} \
} \
} \

EMB_OP(float, emb_f32)
EMB_OP(double, emb_f64)
EMB_OP(uint32_t, emb_u32)
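
For intuition, here is the same gather as a plain-Rust CPU reference for the contiguous case; embedding_ref is a hypothetical name used for illustration and is not part of this commit. It assumes inp is a contiguous [v_size, h_size] table, per the warning at the top of the file.

// Hypothetical CPU reference for the kernel's contiguous path: row ids[i]
// of the [v_size, h_size] table `inp` is copied to row i of `out`.
fn embedding_ref(ids: &[u32], inp: &[f32], h_size: usize, v_size: usize) -> Vec<f32> {
    let mut out = vec![0f32; ids.len() * h_size];
    for (i, &id) in ids.iter().enumerate() {
        let id = id as usize;
        // The TODO above: the kernel does not yet report ids >= v_size.
        assert!(id < v_size, "id {id} out of range for vocab of {v_size}");
        out[i * h_size..(i + 1) * h_size]
            .copy_from_slice(&inp[id * h_size..(id + 1) * h_size]);
    }
    out
}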

@@ -1,4 +1,5 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

@@ -24,6 +24,13 @@ pub enum CudaError
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
}
type Result<T> = std::result::Result<T, CudaError>;
@@ -349,9 +356,6 @@ impl CudaStorage {
rhs_stride: &[usize],
) -> Result<Self> {
let dims = shape.dims();
if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
return Err(CudaError::InternalError("TODO: implement broadcast"));
}
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = self.device();
@@ -425,13 +429,57 @@ impl CudaStorage {
pub(crate) fn embedding_impl(
&self,
_shape: &Shape,
_stride: &[usize],
_rhs: &Self,
_hidden_size: usize,
_vocab_size: usize,
shape: &Shape,
stride: &[usize],
rhs: &Self,
h_size: usize, // hidden size
v_size: usize, // vocab size
) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement embedding"))
let ids = match &self.slice {
CudaStorageSlice::U32(slice) => slice,
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn matmul_impl(

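The ds buffer copied to the device packs dims followed by stride; on the device it arrives as the info argument, which the kernel splits back into dims and strides. On the non-contiguous path each linear position is then remapped through get_strided_index. A plain-Rust sketch of that mapping, for intuition (a hypothetical helper; the real device helper lives in cuda_utils.cuh):

// Hypothetical plain-Rust sketch of the index mapping performed by
// `get_strided_index`: a linear position `i` over `dims` is converted
// to a buffer offset using arbitrary `strides`.
fn get_strided_index(mut i: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    // Peel coordinates off `i` from the innermost dimension outwards.
    for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
        offset += (i % dim) * stride;
        i /= dim;
    }
    offset
}
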
@@ -46,6 +46,7 @@ macro_rules! with_dtype {
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
@@ -56,6 +57,7 @@
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
@@ -66,6 +68,7 @@
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}

@@ -3,8 +3,12 @@ use crate::{DType, DeviceLocation, Shape};
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
UnexpectedDType { expected: DType, got: DType },
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
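
The new msg field lets each call site name the check that failed. A hypothetical construction, not from this commit, showing how the thiserror attribute renders the variant used by the CUDA backend's id check above:

// Hypothetical example of the enriched variant.
let err = Error::UnexpectedDType {
    msg: "embedding ids should be u32",
    expected: DType::U32,
    got: DType::F32,
};
// Via #[error("{msg}, expected: {expected:?}, got: {got:?}")] this
// displays as:
//   embedding ids should be u32, expected: U32, got: F32
assert_eq!(
    err.to_string(),
    "embedding ids should be u32, expected: U32, got: F32"
);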