diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index fd2ac94c..67a08714 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -15,6 +15,8 @@ pub trait BackendStorage: Sized {
 
     fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
 
+    fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
+
     fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
 
     fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index c6d55e61..adb3e1dd 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -105,6 +105,7 @@ impl Tensor {
                 | Op::Narrow(node, _, _, _)
                 | Op::Unary(node, _)
                 | Op::Elu(node, _)
+                | Op::Powf(node, _)
                 | Op::CustomOp1(node, _) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
@@ -437,6 +438,11 @@ impl Tensor {
                         *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                     }
                     Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
+                    Op::Powf(arg, e) => {
+                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                     Op::CustomOp1(arg, c) => {
                         if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                             let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index d4615b0a..4a061c39 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1929,6 +1929,32 @@ impl BackendStorage for CpuStorage {
         UpsampleNearest2D(h, w).map(self, layout)
     }
 
+    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
+        use num_traits::Float;
+        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
+        match self {
+            Self::BF16(storage) => {
+                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
+                Ok(Self::BF16(data))
+            }
+            Self::F16(storage) => {
+                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
+                Ok(Self::F16(data))
+            }
+            Self::F32(storage) => {
+                let data = unary_map(storage, layout, |v| v.powf(e as f32));
+                Ok(Self::F32(data))
+            }
+            Self::F64(storage) => {
+                let data = unary_map(storage, layout, |v| v.powf(e));
+                Ok(Self::F64(data))
+            }
+            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
+            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
+            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
+        }
+    }
+
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         // TODO: Have some generic map for functions that apply on num_traits::Float elements.
         match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index cd06e8d7..14a77b52 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -593,6 +593,30 @@ impl Map1 for Elu {
     }
 }
 
+struct Powf(f64);
+impl Map1 for Powf {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
+        let src = &src.slice(layout.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(out)
+    }
+}
+
 struct Sum<'a>(&'a [usize]);
 impl<'a> Map1 for Sum<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1531,6 +1555,12 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
+    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
+        let device = self.device().clone();
+        let slice = Powf(e).map(&self.slice, &device, layout)?;
+        Ok(Self { slice, device })
+    }
+
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         let device = self.device().clone();
         let slice = Elu(alpha).map(&self.slice, &device, layout)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 56eda7f3..6c896653 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 3fe52ebc..213ae2c8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -134,6 +134,7 @@ pub enum Op {
     Transpose(Tensor, usize, usize),
     Permute(Tensor, Vec<usize>),
     Elu(Tensor, f64),
+    Powf(Tensor, f64),
     CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
     CustomOp2(
         Tensor,
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index d47ab292..8bd14ea9 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -68,6 +68,19 @@ impl Storage {
         }
     }
 
+    pub(crate) fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
+        match self {
+            Storage::Cpu(storage) => {
+                let storage = storage.powf(layout, e)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.powf(layout, e)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
+    }
+
     pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f834e040..12e98029 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -535,6 +535,13 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Raises each element of the tensor to the float exponent `e`.
+    pub fn powf(&self, e: f64) -> Result<Self> {
+        let storage = self.storage().powf(self.layout(), e)?;
+        let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
         if dim >= self.dims().len() {
             Err(Error::DimOutOfRange {
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 81b1338a..26e05b68 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -173,6 +173,16 @@ fn unary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
+
+    let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
+    let y = x.powf(2.5)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 2)?,
+        [12.99, 2.5, 20.0, 0.15]
+    );
     Ok(())
 }
 
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index 85d74b82..c5b18461 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -91,6 +91,7 @@ UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
 UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
 UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
 UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
+UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -107,6 +108,7 @@ UNARY_OP(__half, usqrt_f16, sqrtg(x))
 UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
 UNARY_OP(__half, urelu_f16, relu_fwd(x))
 UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
+UNARY_OP1(__half, upowf_f16, powg(x, param))
 #endif
 
 UNARY_OP(uint8_t, ucopy_u8, x)
@@ -137,3 +139,5 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
 UNARY_OP(double, urelu_f64, relu_fwd(x))
 UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
 UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
+UNARY_OP1(float, upowf_f32, powg(x, param))
+UNARY_OP1(double, upowf_f64, powg(x, param))
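The new `Op::Powf` backward arm implements d/dx x^e = e * x^(e-1), scaled by the incoming gradient. A standalone sanity check of that rule (plain Rust, no candle dependency; the values mirror the new `grad_tests.rs` case) could look like this sketch:

```rust
// Checks the analytic gradient used by the Op::Powf backward arm,
// d/dx x^e = e * x^(e-1), against a central finite difference.
fn main() {
    let e = 2.5f64;
    for x in [3.0f64, 1.0, 4.0, 0.15] {
        // What the backward arm computes (before multiplying by the upstream grad).
        let analytic = e * x.powf(e - 1.0);
        let h = 1e-6;
        let numeric = ((x + h).powf(e) - (x - h).powf(e)) / (2.0 * h);
        assert!((analytic - numeric).abs() < 1e-3);
        println!("x = {x}: analytic {analytic:.4}, numeric {numeric:.4}");
    }
}
```

At x = 3 this gives 2.5 * 3^1.5 ≈ 12.99, which is the first gradient value asserted in the new test.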
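For completeness, a minimal end-to-end usage sketch of the new `Tensor::powf` API, assuming the crate is consumed as `candle_core` and run on the CPU backend; it mirrors the added `grad_tests.rs` case rather than documenting a canonical example:

```rust
use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Forward: elementwise x^2.5 through the new Tensor::powf.
    let x = Var::new(&[3f32, 1., 4., 0.15], &device)?;
    let y = x.powf(2.5)?;
    // Backward: the Op::Powf arm yields dy/dx = 2.5 * x^1.5.
    let grads = y.backward()?;
    let grad_x = grads.get(&x).expect("no grad for x");
    println!("y     = {:?}", y.to_vec1::<f32>()?);
    println!("dy/dx = {:?}", grad_x.to_vec1::<f32>()?);
    Ok(())
}
```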