From 8ad03a5fb674973064c4a2679140344e44f6d737 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:37:38 +0100 Subject: [PATCH] Use Map1 on unary ops. --- candle-core/src/cuda_backend.rs | 77 +++++++++++---------------------- candle-core/src/op.rs | 28 ++++-------- 2 files changed, 33 insertions(+), 72 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 6add6eb7..dc0d51bf 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -350,6 +350,29 @@ impl<'a> Map1 for Sum<'a> { } } +impl Map1 for U { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -539,58 +562,8 @@ impl CudaStorage { } pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { - let shape = layout.shape(); - let dims = shape.dims(); - let el_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = &self.device; - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(_arg) => { - todo!("No unary kernels for u32"); - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = U::V.map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 7b0e18fe..bbbd4bac 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -43,11 +43,8 @@ pub(crate) enum Op { pub(crate) trait UnaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16) -> bf16; fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; @@ -121,11 +118,8 @@ macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { impl UnaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("u", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("u", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("u", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("u", $name, "_u32"); + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; fn bf16($a: bf16) -> bf16 { $e } @@ -158,6 +152,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt()); /// impl UnaryOp for Gelu { const NAME: &'static str = "gelu"; + const V: Self = Gelu; fn bf16(v: bf16) -> bf16 { bf16::from_f32_const(0.5) * v @@ -191,20 +186,13 @@ impl UnaryOp for Gelu { fn u32(_: u32) -> u32 { 0 } - const KERNEL_BF16: &'static str = "ugelu_bf16"; - const KERNEL_F16: &'static str = "ugelu_f16"; - const KERNEL_F32: &'static str = "ugelu_f32"; - const KERNEL_F64: &'static str = "ugelu_f64"; - const KERNEL_U32: &'static str = "ugelu_u32"; + const KERNEL: &'static str = "ugelu"; } impl UnaryOp for Relu { const NAME: &'static str = "relu"; - const KERNEL_BF16: &'static str = "urelu_bf16"; - const KERNEL_F16: &'static str = "urelu_f16"; - const KERNEL_F32: &'static str = "urelu_f32"; - const KERNEL_F64: &'static str = "urelu_f64"; - const KERNEL_U32: &'static str = "urelu_u32"; + const KERNEL: &'static str = "urelu"; + const V: Self = Relu; fn bf16(v: bf16) -> bf16 { v.max(bf16::ZERO) }