mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
(Properly) add the where kernels.
This commit is contained in:
@ -481,7 +481,7 @@ impl CudaStorage {
|
||||
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
||||
let slice = match (&t.slice, &f.slice) {
|
||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
||||
let func = dev.get_or_load_func("where_f32", kernels::BINARY)?;
|
||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
@ -491,7 +491,7 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_f64", kernels::BINARY)?;
|
||||
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
@ -500,7 +500,7 @@ impl CudaStorage {
|
||||
}
|
||||
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_u32", kernels::BINARY)?;
|
||||
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
|
Reference in New Issue
Block a user