Cuda support for embedding f16.

This commit is contained in:
laurent
2023-06-26 21:58:15 +01:00
parent becb822ce0
commit d204f1c7c0
2 changed files with 5 additions and 1 deletions

View File

@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \
#if __CUDA_ARCH__ >= 530
EMB_OP(__half, emb_f16)
#endif
EMB_OP(float, emb_f32)
EMB_OP(double, emb_f64)
EMB_OP(uint32_t, emb_u32)

View File

@ -726,7 +726,7 @@ impl CudaStorage {
  let slice = match &rhs.slice {
      // The kernels below assume that rhs is contiguous.
      CudaStorageSlice::U32(arg) => {
-         let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
+         let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
          // SAFETY: Set later by running the kernel.
          let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
          let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);