Refactor the CUDA "where" op onto a new two-input dispatch trait.

Summary of what the hunks below do:
  * add `type S = CudaStorageSlice;` as a short alias and use it in `Map1::map`;
  * add `trait Map2`, whose `map` dispatches on a pair of storage slices with
    the same dtype variant and returns `CudaError::InternalError` otherwise;
  * add `struct WhereCond`, which implements `Map2::f` by launching the
    dtype-specific "where" kernel from `kernels::TERNARY`;
  * replace the hand-written per-dtype match in the `impl CudaStorage` method
    touched by the last hunk with a single `WhereCond(..).map(..)` call, and
    rename its `layout_t` / `layout_f` parameters to `t_l` / `f_l`.

NOTE(review): this patch was rebuilt from a copy whose line breaks had been
collapsed and whose generic arguments had been stripped. The line structure
was checked against the hunk headers. The concrete element types
(`bf16`, `f16`, `f32`, `f64`, `u32`) follow from the enum variants each value
is wrapped in. The `T: DeviceRepr + WithDType` bound on the generic `f`
methods is an assumption and should be confirmed against the target tree.

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7d06dd72..0e9c11c8 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -244,6 +244,7 @@ enum CudaStorageSlice {
     F32(CudaSlice<f32>),
     F64(CudaSlice<f64>),
 }
+type S = CudaStorageSlice;
 
 trait Map1 {
     fn f<T: DeviceRepr + WithDType>(
@@ -253,13 +254,36 @@ trait Map1 {
         layout: &Layout,
     ) -> Result<CudaSlice<T>>;
 
-    fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> {
+    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
         let out = match s {
-            CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?),
-            CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?),
-            CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?),
-            CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?),
-            CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?),
+            S::U32(s) => S::U32(self.f(s, d, l)?),
+            S::BF16(s) => S::BF16(self.f(s, d, l)?),
+            S::F16(s) => S::F16(self.f(s, d, l)?),
+            S::F32(s) => S::F32(self.f(s, d, l)?),
+            S::F64(s) => S::F64(self.f(s, d, l)?),
+        };
+        Ok(out)
+    }
+}
+
+trait Map2 {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src1: &CudaSlice<T>,
+        layout1: &Layout,
+        src2: &CudaSlice<T>,
+        layout2: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>>;
+
+    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
+        let out = match (s1, s2) {
+            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
+            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
+            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
+            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
+            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
+            _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
         };
         Ok(out)
     }
@@ -411,6 +435,44 @@ impl<'a> Map1 for Embedding<'a> {
     }
 }
 
+struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map2 for WhereCond<'a> {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        t: &CudaSlice<T>,
+        layout_t: &Layout,
+        f: &CudaSlice<T>,
+        layout_f: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>> {
+        let ids_l = &self.1;
+        let ids = match &self.0.slice {
+            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "where conditions should be u32",
+                expected: DType::U32,
+                got: self.0.dtype(),
+            })?,
+        };
+        let ids = &ids;
+        let shape = ids_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds =
+            dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
+        let t = &t.slice(layout_t.start_offset()..);
+        let f = &f.slice(layout_f.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }?;
+        let params = (el, dims.len(), &ds, ids, t, f, &out);
+        // SAFETY: ffi
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 fn slice_src_and_dst<'a, T>(
     src: &'a CudaSlice<T>,
     src_l: &Layout,
@@ -714,86 +776,12 @@ impl CudaStorage {
         &self,
         layout: &Layout,
         t: &Self,
-        layout_t: &Layout,
+        t_l: &Layout,
         f: &Self,
-        layout_f: &Layout,
+        f_l: &Layout,
     ) -> Result<Self> {
-        let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
-            _ => Err(CudaError::UnexpectedDType {
-                msg: "where conditions should be u32",
-                expected: DType::U32,
-                got: self.dtype(),
-            })?,
-        };
-        let ids = &ids;
-        let shape = layout.shape();
-        let dims = shape.dims();
-        let el = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let dev = self.device();
-        let ds =
-            dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
-        let slice = match (&t.slice, &f.slice) {
-            (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
-                let t = &t.slice(layout_t.start_offset()..);
-                let f = &f.slice(layout_f.start_offset()..);
-                let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<bf16>(el) }?;
-                let params = (el, dims.len(), &ds, ids, t, f, &out);
-                // SAFETY: ffi
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::BF16(out)
-            }
-            (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
-                let t = &t.slice(layout_t.start_offset()..);
-                let f = &f.slice(layout_f.start_offset()..);
-                let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f16>(el) }?;
-                let params = (el, dims.len(), &ds, ids, t, f, &out);
-                // SAFETY: ffi
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F16(out)
-            }
-            (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
-                let t = &t.slice(layout_t.start_offset()..);
-                let f = &f.slice(layout_f.start_offset()..);
-                let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f32>(el) }?;
-                let params = (el, dims.len(), &ds, ids, t, f, &out);
-                // SAFETY: ffi
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F32(out)
-            }
-            (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
-                let t = &t.slice(layout_t.start_offset()..);
-                let f = &f.slice(layout_f.start_offset()..);
-                // SAFETY: Set later by running the kernel.
-                let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
-                let out = unsafe { dev.alloc::<f64>(el) }?;
-                let params = (el, dims.len(), &ds, ids, t, f, &out);
-                // SAFETY: ffi
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F64(out)
-            }
-            (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
-                let t = &t.slice(layout_t.start_offset()..);
-                let f = &f.slice(layout_f.start_offset()..);
-                // SAFETY: Set later by running the kernel.
-                let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
-                let out = unsafe { dev.alloc::<u32>(el) }?;
-                let params = (el, dims.len(), &ds, ids, t, f, &out);
-                // SAFETY: ffi
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::U32(out)
-            }
-            // The dtypes should have been checked at this point so this is an internal error.
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
-        };
-        let device = dev.clone();
+        let device = self.device().clone();
+        let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
         Ok(Self { slice, device })
     }
 