mirror of https://github.com/huggingface/candle.git
Add a cuda kernel for embeddings.

kernels/src/embeddings.cu (new file, +34 lines)
@@ -0,0 +1,34 @@
+// WARNING: THIS IS ONLY VALID ASSUMING THAT inp IS CONTIGUOUS!
+// TODO: proper error reporting when ids are larger than v_size.
+#include "cuda_utils.cuh"
+#include <stdint.h>
+
+#define EMB_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const uint32_t *ids, \
+    const TYPENAME *inp, \
+    TYPENAME *out, \
+    const size_t h_size, \
+    const size_t v_size \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            /* Copy the h_size-element row for token ids[i]; memcpy counts bytes. */ \
+            memcpy(out + i * h_size, inp + ids[i] * h_size, h_size * sizeof(TYPENAME)); \
+        } \
+    } \
+    else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            /* ids is not contiguous: map the linear index through its strides. */ \
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+            memcpy(out + i * h_size, inp + ids[strided_i] * h_size, h_size * sizeof(TYPENAME)); \
+        } \
+    } \
+} \
+
+EMB_OP(float, emb_f32)
+EMB_OP(double, emb_f64)
+EMB_OP(uint32_t, emb_u32)
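The kernel depends on two helpers from cuda_utils.cuh that are not part of this diff: is_contiguous and get_strided_index. As a reading aid, here is a minimal sketch of what they have to do, assuming a row-major stride convention; this is an illustrative reconstruction, not the shipped header.

// Sketch (not the actual cuda_utils.cuh): translate a linear element index
// into a strided offset by peeling off the coordinate in each dimension,
// innermost first, and weighting it by that dimension's stride.
__device__ unsigned int get_strided_index(
    unsigned int idx, const size_t num_dims,
    const size_t *dims, const size_t *strides) {
    unsigned int strided_i = 0;
    for (unsigned int d = 0; d < num_dims; d++) {
        const unsigned int dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

// Sketch: a tensor is contiguous when each stride equals the product of the
// trailing dimension sizes, i.e. the linear index needs no translation.
__device__ bool is_contiguous(
    const size_t num_dims, const size_t *dims, const size_t *strides) {
    size_t acc = 1;
    for (unsigned int d = 0; d < num_dims; d++) {
        const unsigned int dim_idx = num_dims - 1 - d;
        if (acc != strides[dim_idx]) return false;
        acc *= dims[dim_idx];
    }
    return true;
}

With these, the else branch of the kernel translates the linear loop index into the strided offset of ids before performing the row copy.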
kernels/src/lib.rs
@@ -1,4 +1,5 @@
 pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
 pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
+pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
src/cuda_backend.rs
@@ -24,6 +24,13 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),

+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },
 }

 type Result<T> = std::result::Result<T, CudaError>;
@@ -349,9 +356,6 @@ impl CudaStorage {
         rhs_stride: &[usize],
     ) -> Result<Self> {
         let dims = shape.dims();
-        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
-            return Err(CudaError::InternalError("TODO: implement broadcast"));
-        }
         let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
@@ -425,13 +429,57 @@ impl CudaStorage {

     pub(crate) fn embedding_impl(
         &self,
-        _shape: &Shape,
-        _stride: &[usize],
-        _rhs: &Self,
-        _hidden_size: usize,
-        _vocab_size: usize,
+        shape: &Shape,
+        stride: &[usize],
+        rhs: &Self,
+        h_size: usize, // hidden size
+        v_size: usize, // vocab size
     ) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement embedding"))
+        let ids = match &self.slice {
+            CudaStorageSlice::U32(slice) => slice,
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "embedding ids should be u32",
+                expected: DType::U32,
+                got: self.dtype(),
+            })?,
+        };
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let dev = self.device();
+        let ds = dev.htod_copy([dims, stride].concat())?;
+        let slice = match &rhs.slice {
+            // The kernels below assume that rhs is contiguous.
+            CudaStorageSlice::U32(arg) => {
+                let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::U32(out)
+            }
+            CudaStorageSlice::F32(arg) => {
+                let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F32(out)
+            }
+            CudaStorageSlice::F64(arg) => {
+                let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
+                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F64(out)
+            }
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }

     pub(crate) fn matmul_impl(
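On the Rust side, embedding_impl packs (numel, num_dims, info, ids, inp, out, h_size, v_size) into the launch parameters and dispatches one of the emb_* kernels through the cudarc bindings. For readers without that harness, the following standalone CUDA program exercises the same contiguous row-gather end to end; the kernel name emb_f32_demo, its simplified signature (no info/num_dims/v_size), and the toy sizes V, H, N are all illustrative, not part of the commit.

// Standalone sketch of the contiguous path: one grid-stride loop, one
// h_size-wide row copied per index. Error checking elided for brevity.
#include <cstdint>
#include <cstdio>
#include <cstring>

extern "C" __global__ void emb_f32_demo(
    const size_t numel, const uint32_t *ids,
    const float *inp, float *out, const size_t h_size) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        memcpy(out + i * h_size, inp + ids[i] * h_size, h_size * sizeof(float));
    }
}

int main() {
    const size_t V = 4, H = 3, N = 2; // vocab size, hidden size, number of ids
    float table[V * H];               // row r of the embedding table holds the value r
    for (size_t i = 0; i < V * H; i++) table[i] = (float)(i / H);
    uint32_t ids[N] = {3, 1};

    float *d_inp, *d_out;
    uint32_t *d_ids;
    cudaMalloc(&d_inp, sizeof(table));
    cudaMalloc(&d_out, N * H * sizeof(float));
    cudaMalloc(&d_ids, sizeof(ids));
    cudaMemcpy(d_inp, table, sizeof(table), cudaMemcpyHostToDevice);
    cudaMemcpy(d_ids, ids, sizeof(ids), cudaMemcpyHostToDevice);

    emb_f32_demo<<<1, 32>>>(N, d_ids, d_inp, d_out, H);

    float out[N * H];
    cudaMemcpy(out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < N * H; i++) printf("%g ", out[i]);
    printf("\n");
    cudaFree(d_inp); cudaFree(d_out); cudaFree(d_ids);
    return 0;
}

Compiled with nvcc, this should print "3 3 3 1 1 1": each id selects one H-wide row of the table, mirroring what the emb_f32 launch does for a contiguous ids tensor.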
@@ -46,6 +46,7 @@ macro_rules! with_dtype {
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
@@ -56,6 +57,7 @@ macro_rules! with_dtype {
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
@@ -66,6 +68,7 @@ macro_rules! with_dtype {
                 _ => Err(Error::UnexpectedDType {
                     expected: DType::$dtype,
                     got: s.dtype(),
+                    msg: "unexpected dtype",
                 }),
             }
         }
src/error.rs
@@ -3,8 +3,12 @@ use crate::{DType, DeviceLocation, Shape};
 /// Main library error type.
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
-    #[error("unexpected dtype, expected: {expected:?}, got: {got:?}")]
-    UnexpectedDType { expected: DType, got: DType },
+    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+    UnexpectedDType {
+        msg: &'static str,
+        expected: DType,
+        got: DType,
+    },

     #[error("{op} only supports contiguous tensors")]
     RequiresContiguous { op: &'static str },