mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the elu op. (#113)
This commit is contained in:
@ -357,6 +357,30 @@ impl Map1 for Affine {
|
||||
}
|
||||
}
|
||||
|
||||
struct Elu(f64);
|
||||
impl Map1 for Elu {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el) }?;
|
||||
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct Sum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
@ -815,6 +839,12 @@ impl CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
|
||||
|
Reference in New Issue
Block a user