(Properly) add the where kernels.

This commit is contained in:
laurent
2023-06-26 13:25:56 +01:00
parent cd2a171c06
commit 33c0234a33

View File

@@ -481,7 +481,7 @@ impl CudaStorage {
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
-let func = dev.get_or_load_func("where_f32", kernels::BINARY)?;
+let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
@@ -491,7 +491,7 @@ impl CudaStorage {
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
// SAFETY: Set later by running the kernel.
-let func = dev.get_or_load_func("where_f64", kernels::BINARY)?;
+let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
@@ -500,7 +500,7 @@ impl CudaStorage {
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
// SAFETY: Set later by running the kernel.
-let func = dev.get_or_load_func("where_u32", kernels::BINARY)?;
+let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi