Add the powf op. (#664)

* Add the powf op.

* Cuda kernels and backprop.

* Add a test.
This commit is contained in:
Laurent Mazare
2023-08-29 20:48:18 +01:00
committed by GitHub
parent 2d3fcad267
commit 59b731de99
10 changed files with 103 additions and 0 deletions

View File

@ -15,6 +15,8 @@ pub trait BackendStorage: Sized {
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>; fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self>;
fn powf(&self, _: &Layout, _: f64) -> Result<Self>;
fn elu(&self, _: &Layout, _: f64) -> Result<Self>; fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>; fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;

View File

@ -105,6 +105,7 @@ impl Tensor {
| Op::Narrow(node, _, _, _) | Op::Narrow(node, _, _, _)
| Op::Unary(node, _) | Op::Unary(node, _)
| Op::Elu(node, _) | Op::Elu(node, _)
| Op::Powf(node, _)
| Op::CustomOp1(node, _) => { | Op::CustomOp1(node, _) => {
let (tg, nodes) = walk(node, nodes, already_seen); let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg; track_grad |= tg;
@ -437,6 +438,11 @@ impl Tensor {
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)? *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
} }
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?, Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
// Power rule: d/dx x^e = e * x^(e - 1), multiplied by the incoming gradient `grad`.
Op::Powf(arg, e) => {
let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
// Add this contribution to the gradient tracked for `arg`
// (`or_insert` presumably creates the entry when it is missing — confirm).
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::CustomOp1(arg, c) => { Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? { if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;

View File

@ -1929,6 +1929,32 @@ impl BackendStorage for CpuStorage {
UpsampleNearest2D(h, w).map(self, layout) UpsampleNearest2D(h, w).map(self, layout)
} }
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
use num_traits::Float;
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
Self::BF16(storage) => {
let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
Ok(Self::F16(data))
}
Self::F32(storage) => {
let data = unary_map(storage, layout, |v| v.powf(e as f32));
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements. // TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self { match self {

View File

@ -593,6 +593,30 @@ impl Map1 for Elu {
} }
} }
/// Element-wise power op; the single field is the exponent as an `f64`.
struct Powf(f64);

impl Map1 for Powf {
    /// Runs the `upowf` kernel for element type `T` over `src` as described by `layout` and
    /// returns a newly allocated slice with `layout.shape().elem_count()` elements.
    fn f<T: DeviceRepr + WithDType>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let dims = layout.shape().dims();
        let elem_count = layout.shape().elem_count();
        let launch_cfg = LaunchConfig::for_num_elems(elem_count as u32);
        // Dims followed by strides, copied to the device as one buffer.
        let dims_and_strides = dev.htod_copy([dims, layout.stride()].concat()).w()?;
        // Skip to the layout's start offset within the source buffer.
        let input = src.slice(layout.start_offset()..);
        let kernel = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
        // SAFETY: Set later by running the kernel.
        let output = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        // The exponent is converted to the element type before being handed to the kernel.
        let exponent = T::from_f64(self.0);
        let params = (
            elem_count,
            dims.len(),
            &dims_and_strides,
            exponent,
            &input,
            &output,
        );
        // SAFETY: ffi.
        unsafe { kernel.launch(launch_cfg, params) }.w()?;
        Ok(output)
    }
}
struct Sum<'a>(&'a [usize]); struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> { impl<'a> Map1 for Sum<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1531,6 +1555,12 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device }) Ok(Self { slice, device })
} }
/// Applies the `Powf` map with exponent `e` to this storage's slice and wraps the result in a
/// new storage that holds a clone of this storage's device handle.
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
    let dev = self.device().clone();
    let powed = Powf(e).map(&self.slice, &dev, layout)?;
    Ok(Self {
        slice: powed,
        device: dev,
    })
}
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone(); let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?; let slice = Elu(alpha).map(&self.slice, &device, layout)?;

View File

@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }
/// Always fails with `Error::NotCompiledWithCudaSupport`; both arguments are ignored.
fn powf(&self, _layout: &Layout, _e: f64) -> Result<Self> {
    Err(Error::NotCompiledWithCudaSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> { fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport) Err(Error::NotCompiledWithCudaSupport)
} }

View File

@ -134,6 +134,7 @@ pub enum Op {
Transpose(Tensor, usize, usize), Transpose(Tensor, usize, usize),
Permute(Tensor, Vec<usize>), Permute(Tensor, Vec<usize>),
Elu(Tensor, f64), Elu(Tensor, f64),
Powf(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>), CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2( CustomOp2(
Tensor, Tensor,

View File

@ -68,6 +68,19 @@ impl Storage {
} }
} }
/// Raises every element to the power `e` by delegating to the `powf` of whichever backend
/// storage (`Cpu` or `Cuda`) this value wraps, and re-wraps the result in the same variant.
///
/// # Errors
/// Propagates any error returned by the backend's `powf`.
pub(crate) fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
    match self {
        Self::Cpu(storage) => Ok(Self::Cpu(storage.powf(layout, e)?)),
        Self::Cuda(storage) => Ok(Self::Cuda(storage.powf(layout, e)?)),
    }
}
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> { pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
match self { match self {
Storage::Cpu(storage) => { Storage::Cpu(storage) => {

View File

@ -535,6 +535,13 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false)) Ok(from_storage(storage, self.shape(), op, false))
} }
/// Computes `self` raised element-wise to the floating-point power `e`.
///
/// The returned tensor has the same shape as `self` and carries an `Op::Powf` backprop op
/// built through `BackpropOp::new1`.
pub fn powf(&self, e: f64) -> Result<Self> {
    let powed = self.storage().powf(self.layout(), e)?;
    let backprop = BackpropOp::new1(self, |arg| Op::Powf(arg, e));
    Ok(from_storage(powed, self.shape(), backprop, false))
}
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() { if dim >= self.dims().len() {
Err(Error::DimOutOfRange { Err(Error::DimOutOfRange {

View File

@ -173,6 +173,16 @@ fn unary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?; let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]); assert_eq!(y.to_vec1::<f32>()?, [6., 2., 8., 0.3]);
assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]); assert_eq!(grad_x.to_vec1::<f32>()?, [2., 2., 2., 2.]);
// powf: y = x^2.5, so dy/dx = 2.5 * x^1.5. Both the forward values and the gradient are
// compared after rounding to 2 decimals.
let x = Var::new(&[3f32, 1., 4., 0.15], device)?;
let y = x.powf(2.5)?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(test_utils::to_vec1_round(&y, 2)?, [15.59, 1.0, 32.0, 0.01]);
assert_eq!(
test_utils::to_vec1_round(grad_x, 2)?,
[12.99, 2.5, 20.0, 0.15]
);
Ok(()) Ok(())
} }

View File

@ -91,6 +91,7 @@ UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x)) UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x)) UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param)) UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
#endif #endif
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
@ -107,6 +108,7 @@ UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, ugelu_f16, gelu_fwd(x)) UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x)) UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param)) UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP1(__half, upowf_f16, powg(x, param))
#endif #endif
UNARY_OP(uint8_t, ucopy_u8, x) UNARY_OP(uint8_t, ucopy_u8, x)
@ -137,3 +139,5 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x)) UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param)) UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
UNARY_OP1(double, uelu_f64, elu_fwd(x, param)) UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
// Power kernels for float and double, instantiated through UNARY_OP1 with the expression
// powg(x, param).
// NOTE(review): `powg` and `UNARY_OP1` are defined elsewhere; `param` is presumably the
// exponent passed from the host side — confirm.
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))