diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a44f732f..46a2c8de 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -76,6 +76,7 @@ impl Tensor {
                 | Op::Sqrt(node)
                 | Op::Gelu(node)
                 | Op::Relu(node)
+                | Op::Elu(node, _)
                 | Op::Exp(node)
                 | Op::Log(node)
                 | Op::Sin(node)
@@ -250,6 +251,7 @@ impl Tensor {
                     }
                     Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                     Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
+                    Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
                     Op::Sqr(arg) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                         let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 6d1cea3b..15982040 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -463,6 +463,14 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
     Ok(())
 }
 
+fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
+    if v.is_sign_positive() {
+        v
+    } else {
+        (v.exp() - T::one()) * alpha
+    }
+}
+
 impl CpuStorage {
     pub fn dtype(&self) -> DType {
         match self {
@@ -666,6 +674,30 @@ impl CpuStorage {
         Affine(mul, add).map(self, layout)
     }
 
+    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
+        match self {
+            Self::BF16(storage) => {
+                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
+                Ok(Self::BF16(data))
+            }
+            Self::F16(storage) => {
+                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
+                Ok(Self::F16(data))
+            }
+            Self::F32(storage) => {
+                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let data = unary_map(storage, layout, |v| elu(v, alpha));
+                Ok(Self::F64(data))
+            }
+            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
+            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
+        }
+    }
+
     pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 3df34cee..3df1b33c 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -357,6 +357,30 @@ impl Map1 for Affine {
     }
 }
 
+struct Elu(f64);
+impl Map1 for Elu {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }?;
+        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 #[allow(dead_code)]
 struct Sum<'a>(&'a [usize]);
 impl<'a> Map1 for Sum<'a> {
@@ -815,6 +839,12 @@ impl CudaStorage {
         Ok(Self { slice, device })
     }
 
+    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+        let device = self.device().clone();
+        let slice = Elu(alpha).map(&self.slice, &device, layout)?;
+        Ok(Self { slice, device })
+    }
+
     pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device().clone();
         let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 0dbd8d54..9263adee 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -64,6 +64,10 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 9f12b9a2..d8f3b4b4 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -127,6 +127,9 @@ pub enum Error {
     #[error("unsupported safetensor dtype {0:?}")]
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
+    #[error("unsupported dtype {0:?} for op {1}")]
+    UnsupportedDTypeForOp(DType, &'static str),
+
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ee57b325..1b2d800d 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -46,6 +46,7 @@ pub(crate) enum Op {
     Transpose(Tensor, usize, usize),
     Gelu(Tensor),
     Relu(Tensor),
+    Elu(Tensor, f64),
     // TODO: Support for custom ops.
 }
 
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 53ea1544..ee12eeb8 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -66,6 +66,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.elu(layout, alpha)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 6f8dc378..f0ce18f9 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -349,6 +349,16 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
    }
 
+    pub fn elu(&self, alpha: f64) -> Result<Self> {
+        let storage = self.storage.elu(self.layout(), alpha)?;
+        let op = if self.track_op() {
+            Some(Op::Elu(self.clone(), alpha))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
         if dim >= self.dims().len() {
             Err(Error::DimOutOfRange {
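
For reference, a minimal usage sketch of the user-facing API added by this diff. It assumes the crate is imported as candle_core (adjust the path to however the library is exposed in this revision); the expected values simply follow the ELU rule used in the CPU kernel above, x for positive inputs and alpha * (exp(x) - 1) otherwise. Note that, per the backprop.rs changes, calling backward() through an elu node still returns BackwardNotSupported, and the U8/U32 dtypes report UnsupportedDTypeForOp.

    // Sketch only: the import path may differ in this revision of the crate.
    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // ELU with alpha = 1.0: positive values pass through unchanged,
        // negative values become alpha * (exp(x) - 1).
        let t = Tensor::new(&[-1.0f32, 0.0, 2.0], &Device::Cpu)?;
        let out = t.elu(1.0)?;
        // Roughly [-0.6321, 0.0, 2.0].
        println!("{:?}", out.to_vec1::<f32>()?);
        Ok(())
    }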