diff --git a/kernels/src/embeddings.cu b/kernels/src/embeddings.cu index 79bd85a4..1dd12cf1 100644 --- a/kernels/src/embeddings.cu +++ b/kernels/src/embeddings.cu @@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \ } \ } \ +#if __CUDA_ARCH__ >= 530 +EMB_OP(__half, emb_f16) +#endif + EMB_OP(float, emb_f32) EMB_OP(double, emb_f64) EMB_OP(uint32_t, emb_u32) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 90cb0f72..d5be8bf6 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -726,7 +726,7 @@ impl CudaStorage { let slice = match &rhs.slice { // The kernels below assume that rhs is contiguous. CudaStorageSlice::U32(arg) => { - let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; + let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);