Mirror of https://github.com/huggingface/candle.git
CUDA support for f16 embeddings.

This adds an f16 instantiation of the embedding kernel, guarded by __CUDA_ARCH__ >= 530, and fixes the u32 branch of the CUDA backend, which was loading the f16 kernel by mistake.
@@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
     } \
 } \
 
+#if __CUDA_ARCH__ >= 530
+EMB_OP(__half, emb_f16)
+#endif
+
 EMB_OP(float, emb_f32)
 EMB_OP(double, emb_f64)
 EMB_OP(uint32_t, emb_u32)
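The EMB_OP macro (defined earlier in the file; its body is truncated in this hunk) stamps out one gather kernel per dtype, and the new __CUDA_ARCH__ >= 530 guard keeps the __half instantiation out of builds for older GPUs, since half-precision device intrinsics require compute capability 5.3 or newer. As a rough, self-contained sketch of what such a contiguous gather kernel does (the real macro also takes num_dims/info parameters and handles strided inputs; the name and simplified signature below are illustrative, not candle's actual expansion):

#include <cuda_fp16.h>
#include <stddef.h>
#include <stdint.h>

// Illustrative standalone analogue of the EMB_OP(__half, emb_f16) expansion
// (hypothetical name, contiguous-only layout for brevity).
#if __CUDA_ARCH__ >= 530
extern "C" __global__ void emb_f16_sketch(
    const size_t numel,     // total output scalars: num_ids * h_size
    const uint32_t *ids,    // token ids, one per output row
    const __half *inp,      // embedding table: v_size rows of h_size
    __half *out,            // gathered rows, one per id
    const size_t h_size     // hidden size (row width)
) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += blockDim.x * gridDim.x) {
        const size_t row = ids[i / h_size]; // which table row this id selects
        const size_t col = i % h_size;      // which element within that row
        out[i] = inp[row * h_size + col];   // pure copy: no half arithmetic needed
    }
}
#endif

On a device below 5.3 the symbol is simply absent from the compiled module, which is why the host side has to be prepared for an f16 kernel lookup to fail rather than assume it exists.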
@@ -726,7 +726,7 @@ impl CudaStorage {
         let slice = match &rhs.slice {
             // The kernels below assume that rhs is contiguous.
             CudaStorageSlice::U32(arg) => {
-                let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
+                let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
                 // SAFETY: Set later by running the kernel.
                 let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
                 let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
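The Rust hunk is a one-name bugfix: the U32 branch was loading the f16 kernel, so u32 embedding lookups would have launched emb_f16 against u32 buffers, or failed outright where emb_f16 does not exist in the module. get_or_load_func resolves a kernel by its exported name from the compiled EMBEDDINGS PTX; a hedged host-side analogue using the raw CUDA driver API (a hypothetical helper, assuming the embeddings module is already loaded) looks like:

#include <cuda.h>
#include <stdio.h>

// Hypothetical analogue of dev.get_or_load_func(name, kernels::EMBEDDINGS):
// resolve a dtype-specific kernel by its exported name from a loaded module.
static CUfunction load_emb_kernel(CUmodule embeddings, const char *dtype) {
    char name[32];
    snprintf(name, sizeof name, "emb_%s", dtype); // e.g. "emb_u32", "emb_f16"
    CUfunction func = NULL;
    if (cuModuleGetFunction(&func, embeddings, name) != CUDA_SUCCESS) {
        // emb_f16 is absent when the module was built for __CUDA_ARCH__ < 530,
        // so a missing symbol must be handled rather than assumed away.
        return NULL;
    }
    return func;
}

Dispatching by string name like this is exactly why the typo mattered: nothing in the type system ties "emb_f16" to the U32 match arm, so the mismatch only surfaces at runtime.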