From b1d6e264da13dc5dca615de3ae2e7fb0a1518fff Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 26 Jun 2023 13:11:14 +0100
Subject: [PATCH] Sketch the where_cond cuda kernel wrapper.

---
 src/cuda_backend.rs | 59 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 52 insertions(+), 7 deletions(-)

diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 7bdd3e03..53fba634 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -459,14 +459,59 @@ impl CudaStorage {
 
     pub(crate) fn where_cond(
         &self,
-        _shape: &Shape,
-        _stride: &[usize],
-        _t: &Self,
-        _stride_t: &[usize],
-        _f: &Self,
-        _stride_f: &[usize],
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
     ) -> Result<Self> {
-        Err(CudaError::InternalError("TODO: implement where_cond"))
+        let ids = match &self.slice {
+            CudaStorageSlice::U32(slice) => slice,
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "embedding ids should be u32",
+                expected: DType::U32,
+                got: self.dtype(),
+            })?,
+        };
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let dev = self.device();
+        let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
+        let slice = match (&t.slice, &f.slice) {
+            (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
+                let func = dev.get_or_load_func("where_f32", kernels::BINARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<f32>(el) }?;
+                let params = (el, dims.len(), &ds, ids, t, f, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F32(out)
+            }
+            (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
+                // SAFETY: Set later by running the kernel.
+                let func = dev.get_or_load_func("where_f64", kernels::BINARY)?;
+                let out = unsafe { dev.alloc::<f64>(el) }?;
+                let params = (el, dims.len(), &ds, ids, t, f, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::F64(out)
+            }
+            (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
+                // SAFETY: Set later by running the kernel.
+                let func = dev.get_or_load_func("where_u32", kernels::BINARY)?;
+                let out = unsafe { dev.alloc::<u32>(el) }?;
+                let params = (el, dims.len(), &ds, ids, t, f, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                CudaStorageSlice::U32(out)
+            }
+            // The dtypes should have been checked at this point so this is an internal error.
+            _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+        };
+        let device = dev.clone();
+        Ok(Self { slice, device })
     }
 
     pub(crate) fn embedding_impl(