From 33c0234a330a3f95f831d7d3c460b4be0c1b0888 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 26 Jun 2023 13:25:56 +0100 Subject: [PATCH] (Properly) add the where kernels. --- src/cuda_backend.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 53fba634..2c96cc6b 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -481,7 +481,7 @@ impl CudaStorage { let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; let slice = match (&t.slice, &f.slice) { (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { - let func = dev.get_or_load_func("where_f32", kernels::BINARY)?; + let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::<f32>(el) }?; let params = (el, dims.len(), &ds, ids, t, f, &out); @@ -491,7 +491,7 @@ impl CudaStorage { } (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_f64", kernels::BINARY)?; + let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; let out = unsafe { dev.alloc::<f64>(el) }?; let params = (el, dims.len(), &ds, ids, t, f, &out); // SAFETY: ffi @@ -500,7 +500,7 @@ impl CudaStorage { } (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_u32", kernels::BINARY)?; + let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; let out = unsafe { dev.alloc::<u32>(el) }?; let params = (el, dims.len(), &ds, ids, t, f, &out); // SAFETY: ffi