diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 0e9c11c8..40b7e67f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -473,6 +473,32 @@ impl<'a> Map2 for WhereCond<'a> { } } +impl Map2 for U { + fn f( + &self, + lhs: &CudaSlice, + lhs_l: &Layout, + rhs: &CudaSlice, + rhs_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + let shape = lhs_l.shape(); + let dims = shape.dims(); + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -673,72 +699,8 @@ impl CudaStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let shape = lhs_l.shape(); - let dims = shape.dims(); - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; - let slice = match (&self.slice, &rhs.slice) { - (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - // The dtypes should have been checked at this point so this is an internal error. - _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; Ok(Self { slice, device }) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index bbbd4bac..db6ef87f 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -54,11 +54,8 @@ pub(crate) trait UnaryOp { pub(crate) trait BinaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16, v2: bf16) -> bf16; fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; @@ -85,11 +82,8 @@ macro_rules! bin_op { ($op:ident, $name: literal, $e: expr) => { impl BinaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("b", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("b", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("b", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("b", $name, "_u32"); + const KERNEL: &'static str = concat!("b", $name); + const V: Self = $op; fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) }