diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 4050b595..a88d62c7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -940,16 +940,22 @@ impl<'a> Map2 for WhereCond<'a> {
         dev: &CudaDevice,
     ) -> Result<CudaSlice<T>> {
         let ids_l = &self.1;
-        let ids = match &self.0.slice {
-            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+        let (ids, name) = match &self.0.slice {
+            CudaStorageSlice::U8(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_u8")
+            }
+            CudaStorageSlice::U32(slice) => {
+                let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+                (ptr, "where_u32")
+            }
             _ => Err(CudaError::UnexpectedDType {
-                msg: "where conditions should be u32",
+                msg: "where conditions should be u8 or u32",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
             .w()?,
         };
-        let ids = &ids;
         let shape = ids_l.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
@@ -959,7 +965,7 @@ impl<'a> Map2 for WhereCond<'a> {
             .w()?;
         let t = &t.slice(layout_t.start_offset()..);
         let f = &f.slice(layout_f.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
         // SAFETY: Set later by running the kernel.
         let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let params = (el, dims.len(), &ds, ids, t, f, &out);
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index c064f6e5..eceb45c8 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -1,12 +1,12 @@
 #include "cuda_utils.cuh"
 #include <stdint.h>
 
-#define WHERE_OP(TYPENAME, FN_NAME) \
+#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
     const size_t num_dims, \
     const size_t *info, \
-    const uint32_t *ids, \
+    const ID_TYPENAME *ids, \
     const TYPENAME *t, \
     const TYPENAME *f, \
     TYPENAME *out \
@@ -33,14 +33,21 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 800
-WHERE_OP(__nv_bfloat16, where_bf16)
+WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
+WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
-WHERE_OP(__half, where_f16)
+WHERE_OP(__half, uint32_t, where_u32_f16)
+WHERE_OP(__half, uint8_t, where_u8_f16)
 #endif
 
-WHERE_OP(float, where_f32)
-WHERE_OP(double, where_f64)
-WHERE_OP(uint8_t, where_u8)
-WHERE_OP(uint32_t, where_u32)
+WHERE_OP(float, uint32_t, where_u32_f32)
+WHERE_OP(double, uint32_t, where_u32_f64)
+WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+
+WHERE_OP(float, uint8_t, where_u8_f32)
+WHERE_OP(double, uint8_t, where_u8_f64)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
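
For context (not part of the patch): a minimal usage sketch of the new path, assuming the candle_core crate's public Tensor::where_cond and Device::new_cuda APIs. With this change, a u8 condition tensor can drive where_cond directly on CUDA, dispatching to the new where_u8_* kernels instead of requiring a cast to u32 first.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // A u8 condition tensor, e.g. the output of a comparison op.
    let mask = Tensor::new(&[1u8, 0, 1, 0], &dev)?;
    let on_true = Tensor::new(&[1f32, 2., 3., 4.], &dev)?;
    let on_false = Tensor::new(&[-1f32, -2., -3., -4.], &dev)?;
    // Dispatches to where_u8_f32 on CUDA: picks from `on_true` where the mask is nonzero.
    let out = mask.where_cond(&on_true, &on_false)?; // [1., -2., 3., -4.]
    println!("{out}");
    Ok(())
}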