Add the elu op. (#113)

This commit is contained in:
Laurent Mazare
2023-07-09 21:56:31 +01:00
committed by GitHub
parent ea5dfa69bc
commit 270997a055
8 changed files with 95 additions and 0 deletions

View File

@ -357,6 +357,30 @@ impl Map1 for Affine {
}
}
struct Elu(f64);
impl Map1 for Elu {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }?;
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
Ok(out)
}
}
#[allow(dead_code)]
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
@ -815,6 +839,12 @@ impl CudaStorage {
Ok(Self { slice, device })
}
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
let device = self.device().clone();
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device().clone();
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;